Compare commits

...

1 Commit

Author SHA1 Message Date
Joan Fontanals Martinez a898c8028b feat: pass DBConfig args to HNSW 2023-06-16 12:40:04 +02:00
5 changed files with 61 additions and 32 deletions

View File

@@ -26,7 +26,8 @@ def docs_to_index():
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) @pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
def test_hnswlib_vectordb_batch(docs_to_index, replicas, shards, protocol, tmpdir): def test_hnswlib_vectordb_batch(docs_to_index, replicas, shards, protocol, tmpdir):
query = docs_to_index[:10] query = docs_to_index[:10]
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db: with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol,
uses_with={'ef': 5000}) as db:
db.index(inputs=docs_to_index) db.index(inputs=docs_to_index)
if replicas > 1: if replicas > 1:
time.sleep(2) time.sleep(2)
@@ -45,11 +46,12 @@ def test_hnswlib_vectordb_batch(docs_to_index, replicas, shards, protocol, tmpdi
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) @pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
def test_hnswlib_vectordb_single_query(docs_to_index, limit, replicas, shards, protocol, tmpdir): def test_hnswlib_vectordb_single_query(docs_to_index, limit, replicas, shards, protocol, tmpdir):
query = docs_to_index[100] query = docs_to_index[100]
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db: with HNSWVectorDB[MyDoc](ef=5000).serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
protocol=protocol) as db:
db.index(inputs=docs_to_index) db.index(inputs=docs_to_index)
if replicas > 1: if replicas > 1:
time.sleep(2) time.sleep(2)
resp = db.search(inputs=query) resp = db.search(inputs=query, limit=limit)
assert len(resp.matches) == min(limit * shards, len(docs_to_index)) assert len(resp.matches) == min(limit * shards, len(docs_to_index))
assert resp.id == resp.matches[0].id assert resp.id == resp.matches[0].id
assert resp.text == resp.matches[0].text assert resp.text == resp.matches[0].text
@@ -62,7 +64,8 @@ def test_hnswlib_vectordb_single_query(docs_to_index, limit, replicas, shards, p
def test_hnswlib_vectordb_delete(docs_to_index, replicas, shards, protocol, tmpdir): def test_hnswlib_vectordb_delete(docs_to_index, replicas, shards, protocol, tmpdir):
query = docs_to_index[0] query = docs_to_index[0]
delete = MyDoc(id=query.id, text='', embedding=np.random.rand(128)) delete = MyDoc(id=query.id, text='', embedding=np.random.rand(128))
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db: with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol,
uses_with={'ef': 5000}) as db:
db.index(inputs=docs_to_index) db.index(inputs=docs_to_index)
if replicas > 1: if replicas > 1:
time.sleep(2) time.sleep(2)
@@ -89,7 +92,8 @@ def test_hnswlib_vectordb_delete(docs_to_index, replicas, shards, protocol, tmpd
def test_hnswlib_vectordb_udpate_text(docs_to_index, replicas, shards, protocol, tmpdir): def test_hnswlib_vectordb_udpate_text(docs_to_index, replicas, shards, protocol, tmpdir):
query = docs_to_index[0] query = docs_to_index[0]
update = MyDoc(id=query.id, text=query.text + '_changed', embedding=query.embedding) update = MyDoc(id=query.id, text=query.text + '_changed', embedding=query.embedding)
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db: with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol,
uses_with={'ef': 5000}) as db:
db.index(inputs=docs_to_index) db.index(inputs=docs_to_index)
if replicas > 1: if replicas > 1:
time.sleep(2) time.sleep(2)
@@ -115,8 +119,8 @@ def test_hnswlib_vectordb_udpate_text(docs_to_index, replicas, shards, protocol,
def test_hnswlib_vectordb_restore(docs_to_index, replicas, shards, protocol, tmpdir): def test_hnswlib_vectordb_restore(docs_to_index, replicas, shards, protocol, tmpdir):
query = docs_to_index[:100] query = docs_to_index[:100]
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, with HNSWVectorDB[MyDoc](ef=5000).serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
protocol=protocol) as db: protocol=protocol) as db:
db.index(docs=docs_to_index) db.index(docs=docs_to_index)
if replicas > 1: if replicas > 1:
time.sleep(2) time.sleep(2)
@@ -129,7 +133,7 @@ def test_hnswlib_vectordb_restore(docs_to_index, replicas, shards, protocol, tmp
assert res.scores[0] < 0.001 # some precision issues, should be 0.0 assert res.scores[0] < 0.001 # some precision issues, should be 0.0
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
protocol=protocol) as new_db: protocol=protocol, uses_with={'ef': 5000}) as new_db:
time.sleep(2) time.sleep(2)
resp = new_db.search(docs=query) resp = new_db.search(docs=query)
assert len(resp) == len(query) assert len(resp) == len(query)

View File

@@ -22,7 +22,7 @@ def docs_to_index():
@pytest.mark.parametrize('call_method', ['docs', 'inputs', 'positional']) @pytest.mark.parametrize('call_method', ['docs', 'inputs', 'positional'])
def test_hnswlib_vectordb_batch(docs_to_index, call_method, tmpdir): def test_hnswlib_vectordb_batch(docs_to_index, call_method, tmpdir):
query = docs_to_index[:10] query = docs_to_index[:10]
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
if call_method == 'docs': if call_method == 'docs':
indexer.index(docs=docs_to_index) indexer.index(docs=docs_to_index)
resp = indexer.search(docs=query) resp = indexer.search(docs=query)
@@ -44,7 +44,7 @@ def test_hnswlib_vectordb_batch(docs_to_index, call_method, tmpdir):
@pytest.mark.parametrize('call_method', ['docs', 'inputs', 'positional']) @pytest.mark.parametrize('call_method', ['docs', 'inputs', 'positional'])
def test_hnswlib_vectordb_single_query(docs_to_index, limit, call_method, tmpdir): def test_hnswlib_vectordb_single_query(docs_to_index, limit, call_method, tmpdir):
query = docs_to_index[100] query = docs_to_index[100]
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
if call_method == 'docs': if call_method == 'docs':
indexer.index(docs=docs_to_index) indexer.index(docs=docs_to_index)
resp = indexer.search(docs=query, limit=limit) resp = indexer.search(docs=query, limit=limit)
@@ -70,7 +70,7 @@ def test_hnswlib_vectordb_search_field(tmpdir):
for _ in range(2000)]) for _ in range(2000)])
query = docs_to_index[100] query = docs_to_index[100]
indexer = HNSWVectorDB[MyDocTens](workspace=str(tmpdir)) indexer = HNSWVectorDB[MyDocTens](workspace=str(tmpdir), ef=5000)
indexer.index(docs=docs_to_index) indexer.index(docs=docs_to_index)
resp = indexer.search(docs=query, search_field='tens') resp = indexer.search(docs=query, search_field='tens')
assert len(resp.matches) == 10 assert len(resp.matches) == 10
@@ -83,7 +83,7 @@ def test_hnswlib_vectordb_search_field(tmpdir):
def test_hnswlib_vectordb_delete(docs_to_index, call_method, tmpdir): def test_hnswlib_vectordb_delete(docs_to_index, call_method, tmpdir):
query = docs_to_index[0] query = docs_to_index[0]
delete = MyDoc(id=query.id, text='', embedding=np.random.rand(128)) delete = MyDoc(id=query.id, text='', embedding=np.random.rand(128))
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
if call_method == 'docs': if call_method == 'docs':
indexer.index(docs=docs_to_index) indexer.index(docs=docs_to_index)
resp = indexer.search(docs=query) resp = indexer.search(docs=query)
@@ -118,7 +118,7 @@ def test_hnswlib_vectordb_delete(docs_to_index, call_method, tmpdir):
def test_hnswlib_vectordb_udpate_text(docs_to_index, call_method, tmpdir): def test_hnswlib_vectordb_udpate_text(docs_to_index, call_method, tmpdir):
query = docs_to_index[0] query = docs_to_index[0]
update = MyDoc(id=query.id, text=query.text + '_changed', embedding=query.embedding) update = MyDoc(id=query.id, text=query.text + '_changed', embedding=query.embedding)
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
if call_method == 'docs': if call_method == 'docs':
indexer.index(docs=docs_to_index) indexer.index(docs=docs_to_index)
resp = indexer.search(docs=query) resp = indexer.search(docs=query)
@@ -152,7 +152,7 @@ def test_hnswlib_vectordb_udpate_text(docs_to_index, call_method, tmpdir):
def test_hnswlib_vectordb_restore(docs_to_index, tmpdir): def test_hnswlib_vectordb_restore(docs_to_index, tmpdir):
query = docs_to_index[:100] query = docs_to_index[:100]
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
indexer.index(docs=docs_to_index) indexer.index(docs=docs_to_index)
resp = indexer.search(docs=query) resp = indexer.search(docs=query)
assert len(resp) == len(query) assert len(resp) == len(query)
@@ -162,7 +162,7 @@ def test_hnswlib_vectordb_restore(docs_to_index, tmpdir):
# assert res.text == res.matches[0].text # assert res.text == res.matches[0].text
# assert res.scores[0] < 0.001 # some precision issues, should be 0 # assert res.scores[0] < 0.001 # some precision issues, should be 0
indexer.persist() indexer.persist()
new_indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir)) new_indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
resp = new_indexer.search(docs=query) resp = new_indexer.search(docs=query)
assert len(resp) == len(query) assert len(resp) == len(query)
for res in resp: for res in resp:

View File

@@ -47,6 +47,7 @@ class VectorDB(Generic[TSchema]):
return VectorDBTyped return VectorDBTyped
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._workspace = None
if 'work_dir' in kwargs: if 'work_dir' in kwargs:
self._workspace = kwargs['work_dir'] self._workspace = kwargs['work_dir']
if 'workspace' in kwargs: if 'workspace' in kwargs:
@@ -66,13 +67,16 @@ class VectorDB(Generic[TSchema]):
shards: Optional[int] = None, shards: Optional[int] = None,
replicas: Optional[int] = None, replicas: Optional[int] = None,
peer_ports: Optional[Union[Dict[str, List], List]] = None, peer_ports: Optional[Union[Dict[str, List], List]] = None,
uses_with: Optional[Dict] = None,
definition_file: Optional[str] = None, definition_file: Optional[str] = None,
obj_name: Optional[str] = None, obj_name: Optional[str] = None,
**kwargs): **kwargs):
from jina import Deployment, Flow from jina import Deployment, Flow
is_instance = False is_instance = False
uses_with = uses_with or {}
if isinstance(cls, VectorDB): if isinstance(cls, VectorDB):
is_instance = True is_instance = True
uses_with = uses_with.update(**cls._uses_with)
if is_instance: if is_instance:
workspace = workspace or cls._workspace workspace = workspace or cls._workspace
@@ -134,6 +138,7 @@ class VectorDB(Generic[TSchema]):
if use_deployment: if use_deployment:
jina_object = Deployment(name='indexer', jina_object = Deployment(name='indexer',
uses=uses, uses=uses,
uses_with=uses_with,
port=port, port=port,
protocol=protocol, protocol=protocol,
shards=shards, shards=shards,
@@ -145,6 +150,7 @@ class VectorDB(Generic[TSchema]):
else: else:
jina_object = Flow(port=port, protocol=protocol, **kwargs).add(name='indexer', jina_object = Flow(port=port, protocol=protocol, **kwargs).add(name='indexer',
uses=uses, uses=uses,
uses_with=uses_with,
shards=shards, shards=shards,
replicas=replicas, replicas=replicas,
stateful=stateful, stateful=stateful,
@@ -187,20 +193,20 @@ class VectorDB(Generic[TSchema]):
executor['jcloud'] = executor_jcloud_config executor['jcloud'] = executor_jcloud_config
global_jcloud_config = { global_jcloud_config = {
'labels': { 'labels': {
'app': 'vectordb', 'app': 'vectordb',
}, },
'monitor': { 'monitor': {
'traces': { 'traces': {
'enable': True, 'enable': True,
}, },
'metrics': { 'metrics': {
'enable': True, 'enable': True,
'host': 'http://opentelemetry-collector.monitor.svc.cluster.local', 'host': 'http://opentelemetry-collector.monitor.svc.cluster.local',
'port': 4317, 'port': 4317,
}, },
}, },
} }
flow_dict['jcloud'] = global_jcloud_config flow_dict['jcloud'] = global_jcloud_config
import tempfile import tempfile
from jcloud.flow import CloudFlow from jcloud.flow import CloudFlow

View File

@@ -3,6 +3,7 @@ import string
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import numpy as np
from vectordb.db.executors.typed_executor import TypedExecutor from vectordb.db.executors.typed_executor import TypedExecutor
from jina.serve.executors.decorators import requests, write from jina.serve.executors.decorators import requests, write
@@ -13,12 +14,29 @@ if TYPE_CHECKING:
class HNSWLibIndexer(TypedExecutor): class HNSWLibIndexer(TypedExecutor):
def __init__(self, *args, **kwargs): def __init__(self,
space='l2',
max_elements=1024,
ef_construction=200,
ef=10,
M=16,
allow_replace_deleted=False,
num_threads=1,
*args, **kwargs):
from docarray.index import HnswDocumentIndex from docarray.index import HnswDocumentIndex
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
workspace = self.workspace.replace('[', '_').replace(']', '_') workspace = self.workspace.replace('[', '_').replace(']', '_')
self.work_dir = f'{workspace}' if self.handle_persistence else f'{workspace}/{"".join(random.choice(string.ascii_lowercase) for _ in range(5))}' self.work_dir = f'{workspace}' if self.handle_persistence else f'{workspace}/{"".join(random.choice(string.ascii_lowercase) for _ in range(5))}'
self._indexer = HnswDocumentIndex[self._input_schema](work_dir=self.work_dir) db_conf = HnswDocumentIndex.DBConfig()
db_conf.default_column_config.get(np.ndarray).update({'space': space,
'ef_construction': ef_construction,
'ef': ef,
'max_elements': max_elements,
'M': M,
'allow_replace_deleted': allow_replace_deleted,
'num_threads': num_threads})
db_conf.work_dir = self.work_dir
self._indexer = HnswDocumentIndex[self._input_schema](db_config=db_conf)
def _index(self, docs, *args, **kwargs): def _index(self, docs, *args, **kwargs):
self._indexer.index(docs) self._indexer.index(docs)

View File

@@ -17,7 +17,8 @@ def pass_kwargs_as_params(func):
params[k] = kwargs[k] params[k] = kwargs[k]
if len(params.keys()) > 0: if len(params.keys()) > 0:
for k in params.keys(): for k in params.keys():
kwargs.pop(k) if k in kwargs:
kwargs.pop(k)
kwargs['parameters'] = params kwargs['parameters'] = params
return func(*args, **kwargs) return func(*args, **kwargs)