Skip to content

Commit

Permalink
Fixed Post-Like feature
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed Apr 1, 2024
1 parent c078ac4 commit 46a900e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed some bugs in the CDC RAG application
- Fixed open source RAG Pipeline
- Fixed vllm real-time task concurrency bug
- Fixed Post-Like feature

## [0.1.1](https://github.com/SuperDuperDB/superduperdb/compare/0.0.20...0.1.0) (2023-Feb-09)

Expand Down
1 change: 1 addition & 0 deletions superduperdb/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def schedule_jobs(
select=self.select.copy(),
dependencies=dependencies,
overwrite=overwrite,
**(self.predict_kwargs or {}),
)
]
return out
Expand Down
9 changes: 7 additions & 2 deletions superduperdb/vector_search/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,13 @@ def find_nearest_from_array(self, h, n=100, within_ids=None):
similarities = similarities[0, :]
logging.debug(similarities)
scores = -numpy.sort(-similarities)
ix = numpy.argsort(-similarities)[:n]
ix = ix.tolist()
## different ways of handling
if within_ids:
top_n_idxs = numpy.argsort(-similarities)[:n]
ix = [ix[i] for i in top_n_idxs]
else:
ix = numpy.argsort(-similarities)[:n]
ix = ix.tolist()
scores = scores.tolist()
_ids = [self.index[i] for i in ix]
return _ids, scores
Expand Down
15 changes: 12 additions & 3 deletions superduperdb/vector_search/lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import pyarrow as pa

from superduperdb import CFG
from superduperdb.vector_search.base import BaseVectorSearcher, VectorItem
from superduperdb.vector_search.base import (
BaseVectorSearcher,
VectorIndexMeasureType,
VectorItem,
)


class LanceVectorSearcher(BaseVectorSearcher):
Expand All @@ -30,7 +34,9 @@ def __init__(
):
self.dataset_path = os.path.join(CFG.lance_home, f'{identifier}.lance')
self.dimensions = dimensions
self.measure = measure
self.measure = (
measure.name if isinstance(measure, VectorIndexMeasureType) else measure
)
if h is not None:
if not os.path.exists(self.dataset_path):
os.makedirs(self.dataset_path, exist_ok=True)
Expand Down Expand Up @@ -98,9 +104,11 @@ def find_nearest_from_array(
# NOTE: filter is currently applied AFTER vector-search
# See https://lancedb.github.io/lance/api/python/lance.html#lance.dataset.LanceDataset.scanner
if within_ids:
if isinstance(within_ids, (list, set)):
within_ids = tuple(within_ids)
assert (
type(within_ids) == tuple
), 'within_ids must be a tuple for lance sql parser'
), 'within_ids must be a [tuple | list | set] for lance sql parser'
result = self.dataset.to_table(
columns=['id'],
nearest={
Expand All @@ -110,6 +118,7 @@ def find_nearest_from_array(
'metric': self.measure,
},
filter=f"id in {within_ids}",
prefilter=True,
offset=0,
)
else:
Expand Down

0 comments on commit 46a900e

Please sign in to comment.