Skip to content

Commit

Permalink
feat: support returning computed distance and set distance thresholds…
Browse files Browse the repository at this point in the history
… on VectorQueries (#960)
  • Loading branch information
NickChittle authored Aug 26, 2024
1 parent 53b8aab commit 5c2192d
Show file tree
Hide file tree
Showing 10 changed files with 768 additions and 37 deletions.
16 changes: 13 additions & 3 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,25 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> AsyncVectorQuery:
"""
Finds the closest vector embeddings to the given query vector.
Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
A threshold for which no less similar documents will be returned.
Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
Expand All @@ -250,6 +258,8 @@ def find_nearest(
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)

def count(
Expand Down
20 changes: 16 additions & 4 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,23 +550,35 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> VectorQuery:
"""
Finds the closest vector embeddings to the given query vector.
Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
distance_threshold (Optional[float]):
A threshold for which no less similar documents will be returned.
Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
"""
return self._vector_query().find_nearest(
vector_field, query_vector, limit, distance_measure
vector_field,
query_vector,
limit,
distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)


Expand Down
3 changes: 3 additions & 0 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> BaseVectorQuery:
raise NotImplementedError

Expand Down
14 changes: 14 additions & 0 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self, nested_query) -> None:
self._query_vector: Optional[Vector] = None
self._limit: Optional[int] = None
self._distance_measure: Optional[DistanceMeasure] = None
self._distance_result_field: Optional[str] = None
self._distance_threshold: Optional[float] = None

@property
def _client(self):
Expand All @@ -69,6 +71,11 @@ def _to_protobuf(self) -> query.StructuredQuery:
else:
raise ValueError("Invalid distance_measure")

# Coerce ints to floats as required by the protobuf.
distance_threshold_proto = None
if self._distance_threshold is not None:
distance_threshold_proto = float(self._distance_threshold)

pb = self._nested_query._to_protobuf()
pb.find_nearest = query.StructuredQuery.FindNearest(
vector_field=query.StructuredQuery.FieldReference(
Expand All @@ -77,6 +84,8 @@ def _to_protobuf(self) -> query.StructuredQuery:
query_vector=_helpers.encode_value(self._query_vector),
distance_measure=distance_measure_proto,
limit=self._limit,
distance_result_field=self._distance_result_field,
distance_threshold=distance_threshold_proto,
)
return pb

Expand Down Expand Up @@ -111,12 +120,17 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
):
"""Finds the closest vector embeddings to the given query vector."""
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
self._distance_measure = distance_measure
self._distance_result_field = distance_result_field
self._distance_threshold = distance_threshold
return self

def stream(
Expand Down
17 changes: 14 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,26 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
) -> Type["firestore_v1.vector_query.VectorQuery"]:
"""
Finds the closest vector embeddings to the given query vector.
Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
A threshold for which no less similar documents will be returned.
Returns:
:class`~firestore_v1.vector_query.VectorQuery`: the vector query.
Expand All @@ -271,6 +280,8 @@ def find_nearest(
query_vector=query_vector,
limit=limit,
distance_measure=distance_measure,
distance_result_field=distance_result_field,
distance_threshold=distance_threshold,
)

def count(
Expand Down
Loading

0 comments on commit 5c2192d

Please sign in to comment.