Skip to content

Commit

Permalink
Hot fix Atlas vector-search
Browse files (browse the repository at this point in the history)
  • Loading branch information
blythed committed Nov 28, 2023
1 parent a7b32f3 commit 4584103
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 32 deletions.
30 changes: 14 additions & 16 deletions examples/vector_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
"cell_type": "markdown",
"id": "f283b5675bea4619",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Prerequisites\n",
Expand Down Expand Up @@ -109,17 +112,6 @@
"doc_collection = Collection('documents')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f41b3a35-760e-49aa-8387-6a5efb990ea5",
"metadata": {},
"outputs": [],
"source": [
"# Overall metadata information\n",
"db.metadata"
]
},
{
"cell_type": "markdown",
"id": "6bb5ee2b-f0bb-4660-961d-fdf98833f33d",
Expand Down Expand Up @@ -152,7 +144,10 @@
"cell_type": "markdown",
"id": "420ef3662c07d91e",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"As usual, we insert the data:"
Expand Down Expand Up @@ -215,7 +210,10 @@
"cell_type": "markdown",
"id": "7e8d1d264dd7ba1b",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"For Sentence-Transformers vectors, uncomment the following section:"
Expand Down Expand Up @@ -316,7 +314,7 @@
"display(Markdown('---'))\n",
"\n",
"# Iterate through the query results and display them\n",
"for r in result:\n",
"for r in sorted(result, key=lambda r: -r['score']):\n",
" # Display the document's parent and res values in a formatted way\n",
" display(Markdown(f'### `{r[\"parent\"] + \".\" if r[\"parent\"] else \"\"}{r[\"res\"]}`'))\n",
" \n",
Expand Down Expand Up @@ -344,7 +342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
35 changes: 22 additions & 13 deletions superduperdb/backends/mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,34 @@ def create_vector_index(self, vector_index, dry_run=False):
if re.match('^_outputs\.[A-Za-z0-9_]+\.[A-Za-z0-9_]+', key):
key = key.split('.')[1]
model = vector_index.indexing_listener.model.identifier
fields = {
model: [
version = vector_index.indexing_listener.model.version
fields4 = {
str(version): [
{
"dimensions": vector_index.dimensions,
"similarity": vector_index.measure,
"type": "knnVector",
}
]
}
fields3 = {
model: {
"fields": fields4,
"type": "document",
}
}
fields2 = {
key: {
"fields": fields3,
"type": "document",
}
}
fields1 = {
"_outputs": {
"fields": fields2,
"type": "document",
}
}
index_definition = {
"createSearchIndexes": collection,
"indexes": [
Expand All @@ -139,17 +158,7 @@ def create_vector_index(self, vector_index, dry_run=False):
"definition": {
"mappings": {
"dynamic": True,
"fields": {
"_outputs": {
"fields": {
key: {
"fields": fields,
"type": "document",
}
},
"type": "document",
}
},
"fields": fields1,
}
},
}
Expand Down
5 changes: 4 additions & 1 deletion superduperdb/backends/mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,10 @@ def _replace_document_with_vector(step, vector_index, db):
if indexing_key.startswith('_outputs'):
indexing_key = indexing_key.split('.')[1]
indexing_model = vector_index.indexing_listener.model.identifier
step['$vectorSearch']['path'] = f'_outputs.{indexing_key}.{indexing_model}'
indexing_version = vector_index.indexing_listener.model.version
step['$vectorSearch'][
'path'
] = f'_outputs.{indexing_key}.{indexing_model}.{indexing_version}'
step['$vectorSearch']['index'] = vector_index.identifier
del step['$vectorSearch']['like']
return step
Expand Down
9 changes: 8 additions & 1 deletion superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def build_artifact_store(cfg):
def build(uri, mapping):
logging.debug(f"Parsing data connection URI:{uri}")

if re.match('^mongodb:\/\/|^mongodb\+srv:\/\/', uri) is not None:
if re.match('^mongodb:\/\/', uri) is not None:
name = uri.split('/')[-1]
conn = pymongo.MongoClient(
uri,
serverSelectionTimeoutMS=5000,
)
return mapping['mongodb'](conn, name)

elif re.match('^mongodb\+srv:\/\/', uri):
name = uri.split('/')[-1]
conn = pymongo.MongoClient(
'/'.join(uri.split('/')[:-1]),
serverSelectionTimeoutMS=5000,
)
return mapping['mongodb'](conn, name)
elif uri.startswith('mongomock://'):
name = uri.split('/')[-1]
Expand Down
3 changes: 2 additions & 1 deletion superduperdb/components/vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class VectorIndex(Component):
type_id: t.ClassVar[str] = 'vector_index'

@override
def pre_create(self, db: Datalayer) -> None:
def post_create(self, db: Datalayer) -> None:
super().post_create(db)
if s.CFG.vector_search == s.CFG.data_backend:
if (create := getattr(db.databackend, 'create_vector_index', None)) is None:
msg = 'VectorIndex is not supported by the current database backend'
Expand Down

0 comments on commit 4584103

Please sign in to comment.