Compare commits
1 Commit
main
...
feat-pass-
Author | SHA1 | Date |
---|---|---|
Joan Fontanals Martinez | a898c8028b |
|
@ -26,7 +26,8 @@ def docs_to_index():
|
|||
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
|
||||
def test_hnswlib_vectordb_batch(docs_to_index, replicas, shards, protocol, tmpdir):
|
||||
query = docs_to_index[:10]
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db:
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol,
|
||||
uses_with={'ef': 5000}) as db:
|
||||
db.index(inputs=docs_to_index)
|
||||
if replicas > 1:
|
||||
time.sleep(2)
|
||||
|
@ -45,11 +46,12 @@ def test_hnswlib_vectordb_batch(docs_to_index, replicas, shards, protocol, tmpdi
|
|||
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
|
||||
def test_hnswlib_vectordb_single_query(docs_to_index, limit, replicas, shards, protocol, tmpdir):
|
||||
query = docs_to_index[100]
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db:
|
||||
with HNSWVectorDB[MyDoc](ef=5000).serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
|
||||
protocol=protocol) as db:
|
||||
db.index(inputs=docs_to_index)
|
||||
if replicas > 1:
|
||||
time.sleep(2)
|
||||
resp = db.search(inputs=query)
|
||||
resp = db.search(inputs=query, limit=limit)
|
||||
assert len(resp.matches) == min(limit * shards, len(docs_to_index))
|
||||
assert resp.id == resp.matches[0].id
|
||||
assert resp.text == resp.matches[0].text
|
||||
|
@ -62,7 +64,8 @@ def test_hnswlib_vectordb_single_query(docs_to_index, limit, replicas, shards, p
|
|||
def test_hnswlib_vectordb_delete(docs_to_index, replicas, shards, protocol, tmpdir):
|
||||
query = docs_to_index[0]
|
||||
delete = MyDoc(id=query.id, text='', embedding=np.random.rand(128))
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db:
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol,
|
||||
uses_with={'ef': 5000}) as db:
|
||||
db.index(inputs=docs_to_index)
|
||||
if replicas > 1:
|
||||
time.sleep(2)
|
||||
|
@ -89,7 +92,8 @@ def test_hnswlib_vectordb_delete(docs_to_index, replicas, shards, protocol, tmpd
|
|||
def test_hnswlib_vectordb_udpate_text(docs_to_index, replicas, shards, protocol, tmpdir):
|
||||
query = docs_to_index[0]
|
||||
update = MyDoc(id=query.id, text=query.text + '_changed', embedding=query.embedding)
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol) as db:
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards, protocol=protocol,
|
||||
uses_with={'ef': 5000}) as db:
|
||||
db.index(inputs=docs_to_index)
|
||||
if replicas > 1:
|
||||
time.sleep(2)
|
||||
|
@ -115,8 +119,8 @@ def test_hnswlib_vectordb_udpate_text(docs_to_index, replicas, shards, protocol,
|
|||
def test_hnswlib_vectordb_restore(docs_to_index, replicas, shards, protocol, tmpdir):
|
||||
query = docs_to_index[:100]
|
||||
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
|
||||
protocol=protocol) as db:
|
||||
with HNSWVectorDB[MyDoc](ef=5000).serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
|
||||
protocol=protocol) as db:
|
||||
db.index(docs=docs_to_index)
|
||||
if replicas > 1:
|
||||
time.sleep(2)
|
||||
|
@ -129,7 +133,7 @@ def test_hnswlib_vectordb_restore(docs_to_index, replicas, shards, protocol, tmp
|
|||
assert res.scores[0] < 0.001 # some precision issues, should be 0.0
|
||||
|
||||
with HNSWVectorDB[MyDoc].serve(workspace=str(tmpdir), replicas=replicas, shards=shards,
|
||||
protocol=protocol) as new_db:
|
||||
protocol=protocol, uses_with={'ef': 5000}) as new_db:
|
||||
time.sleep(2)
|
||||
resp = new_db.search(docs=query)
|
||||
assert len(resp) == len(query)
|
||||
|
|
|
@ -22,7 +22,7 @@ def docs_to_index():
|
|||
@pytest.mark.parametrize('call_method', ['docs', 'inputs', 'positional'])
|
||||
def test_hnswlib_vectordb_batch(docs_to_index, call_method, tmpdir):
|
||||
query = docs_to_index[:10]
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
|
||||
if call_method == 'docs':
|
||||
indexer.index(docs=docs_to_index)
|
||||
resp = indexer.search(docs=query)
|
||||
|
@ -44,7 +44,7 @@ def test_hnswlib_vectordb_batch(docs_to_index, call_method, tmpdir):
|
|||
@pytest.mark.parametrize('call_method', ['docs', 'inputs', 'positional'])
|
||||
def test_hnswlib_vectordb_single_query(docs_to_index, limit, call_method, tmpdir):
|
||||
query = docs_to_index[100]
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
|
||||
if call_method == 'docs':
|
||||
indexer.index(docs=docs_to_index)
|
||||
resp = indexer.search(docs=query, limit=limit)
|
||||
|
@ -70,7 +70,7 @@ def test_hnswlib_vectordb_search_field(tmpdir):
|
|||
for _ in range(2000)])
|
||||
|
||||
query = docs_to_index[100]
|
||||
indexer = HNSWVectorDB[MyDocTens](workspace=str(tmpdir))
|
||||
indexer = HNSWVectorDB[MyDocTens](workspace=str(tmpdir), ef=5000)
|
||||
indexer.index(docs=docs_to_index)
|
||||
resp = indexer.search(docs=query, search_field='tens')
|
||||
assert len(resp.matches) == 10
|
||||
|
@ -83,7 +83,7 @@ def test_hnswlib_vectordb_search_field(tmpdir):
|
|||
def test_hnswlib_vectordb_delete(docs_to_index, call_method, tmpdir):
|
||||
query = docs_to_index[0]
|
||||
delete = MyDoc(id=query.id, text='', embedding=np.random.rand(128))
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
|
||||
if call_method == 'docs':
|
||||
indexer.index(docs=docs_to_index)
|
||||
resp = indexer.search(docs=query)
|
||||
|
@ -118,7 +118,7 @@ def test_hnswlib_vectordb_delete(docs_to_index, call_method, tmpdir):
|
|||
def test_hnswlib_vectordb_udpate_text(docs_to_index, call_method, tmpdir):
|
||||
query = docs_to_index[0]
|
||||
update = MyDoc(id=query.id, text=query.text + '_changed', embedding=query.embedding)
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
|
||||
if call_method == 'docs':
|
||||
indexer.index(docs=docs_to_index)
|
||||
resp = indexer.search(docs=query)
|
||||
|
@ -152,7 +152,7 @@ def test_hnswlib_vectordb_udpate_text(docs_to_index, call_method, tmpdir):
|
|||
|
||||
def test_hnswlib_vectordb_restore(docs_to_index, tmpdir):
|
||||
query = docs_to_index[:100]
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
|
||||
indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
|
||||
indexer.index(docs=docs_to_index)
|
||||
resp = indexer.search(docs=query)
|
||||
assert len(resp) == len(query)
|
||||
|
@ -162,7 +162,7 @@ def test_hnswlib_vectordb_restore(docs_to_index, tmpdir):
|
|||
# assert res.text == res.matches[0].text
|
||||
# assert res.scores[0] < 0.001 # some precision issues, should be 0
|
||||
indexer.persist()
|
||||
new_indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
|
||||
new_indexer = HNSWVectorDB[MyDoc](workspace=str(tmpdir), ef=5000)
|
||||
resp = new_indexer.search(docs=query)
|
||||
assert len(resp) == len(query)
|
||||
for res in resp:
|
||||
|
|
|
@ -47,6 +47,7 @@ class VectorDB(Generic[TSchema]):
|
|||
return VectorDBTyped
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._workspace = None
|
||||
if 'work_dir' in kwargs:
|
||||
self._workspace = kwargs['work_dir']
|
||||
if 'workspace' in kwargs:
|
||||
|
@ -66,13 +67,16 @@ class VectorDB(Generic[TSchema]):
|
|||
shards: Optional[int] = None,
|
||||
replicas: Optional[int] = None,
|
||||
peer_ports: Optional[Union[Dict[str, List], List]] = None,
|
||||
uses_with: Optional[Dict] = None,
|
||||
definition_file: Optional[str] = None,
|
||||
obj_name: Optional[str] = None,
|
||||
**kwargs):
|
||||
from jina import Deployment, Flow
|
||||
is_instance = False
|
||||
uses_with = uses_with or {}
|
||||
if isinstance(cls, VectorDB):
|
||||
is_instance = True
|
||||
uses_with = uses_with.update(**cls._uses_with)
|
||||
|
||||
if is_instance:
|
||||
workspace = workspace or cls._workspace
|
||||
|
@ -134,6 +138,7 @@ class VectorDB(Generic[TSchema]):
|
|||
if use_deployment:
|
||||
jina_object = Deployment(name='indexer',
|
||||
uses=uses,
|
||||
uses_with=uses_with,
|
||||
port=port,
|
||||
protocol=protocol,
|
||||
shards=shards,
|
||||
|
@ -145,6 +150,7 @@ class VectorDB(Generic[TSchema]):
|
|||
else:
|
||||
jina_object = Flow(port=port, protocol=protocol, **kwargs).add(name='indexer',
|
||||
uses=uses,
|
||||
uses_with=uses_with,
|
||||
shards=shards,
|
||||
replicas=replicas,
|
||||
stateful=stateful,
|
||||
|
@ -187,20 +193,20 @@ class VectorDB(Generic[TSchema]):
|
|||
executor['jcloud'] = executor_jcloud_config
|
||||
|
||||
global_jcloud_config = {
|
||||
'labels': {
|
||||
'app': 'vectordb',
|
||||
},
|
||||
'monitor': {
|
||||
'traces': {
|
||||
'enable': True,
|
||||
},
|
||||
'metrics': {
|
||||
'enable': True,
|
||||
'host': 'http://opentelemetry-collector.monitor.svc.cluster.local',
|
||||
'port': 4317,
|
||||
},
|
||||
},
|
||||
}
|
||||
'labels': {
|
||||
'app': 'vectordb',
|
||||
},
|
||||
'monitor': {
|
||||
'traces': {
|
||||
'enable': True,
|
||||
},
|
||||
'metrics': {
|
||||
'enable': True,
|
||||
'host': 'http://opentelemetry-collector.monitor.svc.cluster.local',
|
||||
'port': 4317,
|
||||
},
|
||||
},
|
||||
}
|
||||
flow_dict['jcloud'] = global_jcloud_config
|
||||
import tempfile
|
||||
from jcloud.flow import CloudFlow
|
||||
|
|
|
@ -3,6 +3,7 @@ import string
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from vectordb.db.executors.typed_executor import TypedExecutor
|
||||
from jina.serve.executors.decorators import requests, write
|
||||
|
||||
|
@ -13,12 +14,29 @@ if TYPE_CHECKING:
|
|||
|
||||
class HNSWLibIndexer(TypedExecutor):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self,
|
||||
space='l2',
|
||||
max_elements=1024,
|
||||
ef_construction=200,
|
||||
ef=10,
|
||||
M=16,
|
||||
allow_replace_deleted=False,
|
||||
num_threads=1,
|
||||
*args, **kwargs):
|
||||
from docarray.index import HnswDocumentIndex
|
||||
super().__init__(*args, **kwargs)
|
||||
workspace = self.workspace.replace('[', '_').replace(']', '_')
|
||||
self.work_dir = f'{workspace}' if self.handle_persistence else f'{workspace}/{"".join(random.choice(string.ascii_lowercase) for _ in range(5))}'
|
||||
self._indexer = HnswDocumentIndex[self._input_schema](work_dir=self.work_dir)
|
||||
db_conf = HnswDocumentIndex.DBConfig()
|
||||
db_conf.default_column_config.get(np.ndarray).update({'space': space,
|
||||
'ef_construction': ef_construction,
|
||||
'ef': ef,
|
||||
'max_elements': max_elements,
|
||||
'M': M,
|
||||
'allow_replace_deleted': allow_replace_deleted,
|
||||
'num_threads': num_threads})
|
||||
db_conf.work_dir = self.work_dir
|
||||
self._indexer = HnswDocumentIndex[self._input_schema](db_config=db_conf)
|
||||
|
||||
def _index(self, docs, *args, **kwargs):
|
||||
self._indexer.index(docs)
|
||||
|
|
|
@ -17,7 +17,8 @@ def pass_kwargs_as_params(func):
|
|||
params[k] = kwargs[k]
|
||||
if len(params.keys()) > 0:
|
||||
for k in params.keys():
|
||||
kwargs.pop(k)
|
||||
if k in kwargs:
|
||||
kwargs.pop(k)
|
||||
kwargs['parameters'] = params
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue