Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support returning computed distance and set distance thresholds on VectorQueries #960

Merged
merged 16 commits into from
Aug 26, 2024
Merged
16 changes: 13 additions & 3 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,25 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> AsyncVectorQuery:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
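Similarity cutoff: documents less similar to the query vector than
this threshold are not returned. How the threshold is compared depends
on distance_measure (distance <= threshold for EUCLIDEAN and COSINE,
distance >= threshold for DOT_PRODUCT).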

Returns:
:class:`~firestore_v1.vector_query.VectorQuery`: the vector query.
Expand All @@ -250,6 +258,8 @@ def find_nearest(
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)

def count(
Expand Down
20 changes: 16 additions & 4 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,23 +550,35 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> VectorQuery:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
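Similarity cutoff: documents less similar to the query vector than
this threshold are not returned. How the threshold is compared depends
on distance_measure (distance <= threshold for EUCLIDEAN and COSINE,
distance >= threshold for DOT_PRODUCT).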

Returns:
:class:`~firestore_v1.vector_query.VectorQuery`: the vector query.
"""
return self._vector_query().find_nearest(
vector_field, query_vector, limit, distance_measure
vector_field,
query_vector,
limit,
distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)


Expand Down
3 changes: 3 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> BaseVectorQuery:
raise NotImplementedError

Expand Down
14 changes: 14 additions & 0 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self, nested_query) -> None:
self._query_vector: Optional[Vector] = None
self._limit: Optional[int] = None
self._distance_measure: Optional[DistanceMeasure] = None
self._distance_result_field: Optional[str] = None
self._distance_threshold: Optional[float] = None

@property
def _client(self):
Expand All @@ -69,6 +71,11 @@ def _to_protobuf(self) -> query.StructuredQuery:
else:
raise ValueError("Invalid distance_measure")

# Coerce ints to floats as required by the protobuf.
distance_threshold_proto = None
if self._distance_threshold is not None:
distance_threshold_proto = float(self._distance_threshold)
NickChittle marked this conversation as resolved.
Show resolved Hide resolved

pb = self._nested_query._to_protobuf()
pb.find_nearest = query.StructuredQuery.FindNearest(
vector_field=query.StructuredQuery.FieldReference(
Expand All @@ -77,6 +84,8 @@ def _to_protobuf(self) -> query.StructuredQuery:
query_vector=_helpers.encode_value(self._query_vector),
distance_measure=distance_measure_proto,
limit=self._limit,
distance_result_field=self._distance_result_field,
distance_threshold=distance_threshold_proto,
)
return pb

Expand Down Expand Up @@ -111,12 +120,17 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
):
"""Finds the closest vector embeddings to the given query vector."""
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
self._distance_measure = distance_measure
self._distance_result_field = distance_result_field
self._distance_threshold = distance_threshold
return self

def stream(
Expand Down
17 changes: 14 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,26 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> Type["firestore_v1.vector_query.VectorQuery"]:
"""
Finds the closest vector embeddings to the given query vector.

Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
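Similarity cutoff: documents less similar to the query vector than
this threshold are not returned. How the threshold is compared depends
on distance_measure (distance <= threshold for EUCLIDEAN and COSINE,
distance >= threshold for DOT_PRODUCT).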


Returns:
:class:`~firestore_v1.vector_query.VectorQuery`: the vector query.
Expand All @@ -271,6 +280,8 @@ def find_nearest(
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)

def count(
Expand Down
Loading
Loading