Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support returning computed distance and set distance thresholds on VectorQueries #960

Merged
merged 16 commits into from
Aug 26, 2024
Merged
16 changes: 12 additions & 4 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type, Union

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
Expand Down Expand Up @@ -230,17 +230,23 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
) -> AsyncVectorQuery:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
NickChittle marked this conversation as resolved.
Show resolved Hide resolved
distance_threshold (Optional[Union[int, float]]):
A threshold for which no less similar documents will be returned.

Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
Expand All @@ -250,6 +256,8 @@ def find_nearest(
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)

def count(
Expand Down
19 changes: 15 additions & 4 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,23 +550,34 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
NickChittle marked this conversation as resolved.
Show resolved Hide resolved
) -> VectorQuery:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
distance_threshold (Optional[Union[int, float]]):
A threshold for which no less similar documents will be returned.

Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
"""
return self._vector_query().find_nearest(
vector_field, query_vector, limit, distance_measure
vector_field,
query_vector,
limit,
distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)


Expand Down
2 changes: 2 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
) -> BaseVectorQuery:
raise NotImplementedError

Expand Down
13 changes: 13 additions & 0 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self, nested_query) -> None:
self._query_vector: Optional[Vector] = None
self._limit: Optional[int] = None
self._distance_measure: Optional[DistanceMeasure] = None
self._distance_result_field: Optional[str] = None
self._distance_threshold: Optional[Union[int, float]] = None
NickChittle marked this conversation as resolved.
Show resolved Hide resolved

@property
def _client(self):
Expand All @@ -69,6 +71,11 @@ def _to_protobuf(self) -> query.StructuredQuery:
else:
raise ValueError("Invalid distance_measure")

# Coerce ints to floats as required by the protobuf.
distance_threshold_proto = None
if self._distance_threshold:
distance_threshold_proto = float(self._distance_threshold)
NickChittle marked this conversation as resolved.
Show resolved Hide resolved

pb = self._nested_query._to_protobuf()
pb.find_nearest = query.StructuredQuery.FindNearest(
vector_field=query.StructuredQuery.FieldReference(
Expand All @@ -77,6 +84,8 @@ def _to_protobuf(self) -> query.StructuredQuery:
query_vector=_helpers.encode_value(self._query_vector),
distance_measure=distance_measure_proto,
limit=self._limit,
distance_result_field=self._distance_result_field,
distance_threshold=distance_threshold_proto,
)
return pb

Expand Down Expand Up @@ -111,12 +120,16 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
):
"""Finds the closest vector embeddings to the given query vector."""
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
self._distance_measure = distance_measure
self._distance_result_field = distance_result_field
self._distance_threshold = distance_threshold
return self

def stream(
Expand Down
17 changes: 13 additions & 4 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type, Union

from google.api_core import exceptions, gapic_v1
from google.api_core import retry as retries
Expand Down Expand Up @@ -251,17 +251,24 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
) -> Type["firestore_v1.vector_query.VectorQuery"]:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
distance_threshold (Optional[Union[int, float]]):
A threshold for which no less similar documents will be returned.


Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
Expand All @@ -271,6 +278,8 @@ def find_nearest(
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)

def count(
Expand Down
68 changes: 64 additions & 4 deletions tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def on_snapshot(docs, changes, read_time):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

Expand All @@ -199,7 +199,7 @@ def test_vector_search_collection(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_with_filter(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

Expand All @@ -218,10 +218,40 @@ def test_vector_search_collection_with_filter(client, database):
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_with_distance_parameters(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

vector_query = collection.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=3,
distance_result_field="vector_distance",
distance_threshold=1.0,
)
returned = vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 0.0,
}
assert returned[1].to_dict() == {
"embedding": Vector([2.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 1.0,
}

NickChittle marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand All @@ -243,7 +273,7 @@ def test_vector_search_collection_group(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_filter(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand All @@ -262,6 +292,36 @@ def test_vector_search_collection_group_with_filter(client, database):
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_distance_parameters(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

vector_query = collection_group.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=3,
distance_result_field="vector_distance",
distance_threshold=1.0,
)
returned = vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 0.0,
}
assert returned[1].to_dict() == {
"embedding": Vector([2.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 1.0,
}


@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_create_document_w_subcollection(client, cleanup, database):
collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID
Expand Down
70 changes: 66 additions & 4 deletions tests/system/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async def test_document_update_w_int_field(client, cleanup, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)
vector_query = collection.find_nearest(
Expand All @@ -363,7 +363,7 @@ async def test_vector_search_collection(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_with_filter(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)
vector_query = collection.where("color", "==", "red").find_nearest(
Expand All @@ -381,10 +381,40 @@ async def test_vector_search_collection_with_filter(client, database):
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_with_distance_parameters(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

vector_query = collection.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=3,
distance_result_field="vector_distance",
distance_threshold=1.0,
)
returned = await vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 0.0,
}
assert returned[1].to_dict() == {
"embedding": Vector([2.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 1.0,
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand All @@ -406,7 +436,7 @@ async def test_vector_search_collection_group(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group_with_filter(client, database):
# Documents and Indexs are a manual step from util/boostrap_vector_index.py
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand All @@ -425,6 +455,38 @@ async def test_vector_search_collection_group_with_filter(client, database):
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group_with_distance_parameters(
client, database
):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

vector_query = collection_group.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=3,
distance_result_field="vector_distance",
distance_threshold=1.0,
)
returned = await vector_query.get()
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == {
"embedding": Vector([1.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 0.0,
}
assert returned[1].to_dict() == {
"embedding": Vector([2.0, 2.0, 3.0]),
"color": "red",
"vector_distance": 1.0,
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_update_document(client, cleanup, database):
Expand Down
Loading
Loading