diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 15f81be24..ca83c2630 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -230,17 +230,25 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> AsyncVectorQuery: """ Finds the closest vector embeddings to the given query vector. Args: - vector_field(str): An indexed vector field to search upon. Only documents which contain + vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. - distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. + distance_result_field (Optional[str]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): + A threshold for which no less similar documents will be returned. Returns: :class`~firestore_v1.vector_query.VectorQuery`: the vector query. 
@@ -250,6 +258,8 @@ def find_nearest( query_vector=query_vector, limit=limit, distance_measure=distance_measure, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) def count( diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index e2065dc2f..18c62aa33 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -550,23 +550,36 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> VectorQuery: """ Finds the closest vector embeddings to the given query vector. Args: - vector_field(str): An indexed vector field to search upon. Only documents which contain + vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. - distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. + distance_result_field (Optional[str]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): + A threshold for which no less similar documents will be returned. Returns: :class`~firestore_v1.vector_query.VectorQuery`: the vector query. 
""" return self._vector_query().find_nearest( - vector_field, query_vector, limit, distance_measure + vector_field, + query_vector, + limit, + distance_measure, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 73ed00206..cfed454b9 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -982,6 +982,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> BaseVectorQuery: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index 0c5c61b3e..26cd5b199 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -45,6 +45,8 @@ def __init__(self, nested_query) -> None: self._query_vector: Optional[Vector] = None self._limit: Optional[int] = None self._distance_measure: Optional[DistanceMeasure] = None + self._distance_result_field: Optional[str] = None + self._distance_threshold: Optional[float] = None @property def _client(self): @@ -69,6 +71,11 @@ def _to_protobuf(self) -> query.StructuredQuery: else: raise ValueError("Invalid distance_measure") + # Coerce ints to floats as required by the protobuf. 
+ distance_threshold_proto = None + if self._distance_threshold is not None: + distance_threshold_proto = float(self._distance_threshold) + pb = self._nested_query._to_protobuf() pb.find_nearest = query.StructuredQuery.FindNearest( vector_field=query.StructuredQuery.FieldReference( @@ -77,6 +84,8 @@ def _to_protobuf(self) -> query.StructuredQuery: query_vector=_helpers.encode_value(self._query_vector), distance_measure=distance_measure_proto, limit=self._limit, + distance_result_field=self._distance_result_field, + distance_threshold=distance_threshold_proto, ) return pb @@ -111,12 +120,17 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ): """Finds the closest vector embeddings to the given query vector.""" self._vector_field = vector_field self._query_vector = query_vector self._limit = limit self._distance_measure = distance_measure + self._distance_result_field = distance_result_field + self._distance_threshold = distance_threshold return self def stream( diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index b5bd5ec4f..eb8f51dc8 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -251,17 +251,26 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> Type["firestore_v1.vector_query.VectorQuery"]: """ Finds the closest vector embeddings to the given query vector. Args: - vector_field(str): An indexed vector field to search upon. Only documents which contain + vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. 
Must be a vector of no more + query_vector (Vector): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. - distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. + distance_result_field (Optional[str]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): + A threshold for which no less similar documents will be returned. + Returns: :class`~firestore_v1.vector_query.VectorQuery`: the vector query. @@ -271,6 +280,8 @@ def find_nearest( query_vector=query_vector, limit=limit, distance_measure=distance_measure, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) def count( diff --git a/tests/system/test_system.py b/tests/system/test_system.py index dc9d86a10..b67b8aecc 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -176,15 +176,22 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + 
distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -198,15 +205,22 @@ def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection_with_filter(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -220,15 +234,82 @@ def test_vector_search_collection_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +def test_vector_search_collection_with_distance_parameters_euclid(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = vector_query.get() + assert 
isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_with_distance_parameters_cosine(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection_group(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - 
distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -241,16 +322,23 @@ def test_vector_search_collection_group(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +def test_vector_search_collection_group_with_filter(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -262,6 +350,70 @@ def test_vector_search_collection_group_with_filter(client, database): } +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_group_with_distance_parameters_euclid( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert 
len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_group_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index df574e0fa..78bd64c5c 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -341,15 +341,22 @@ async def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def 
test_vector_search_collection(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), limit=1, - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, ) returned = await vector_query.get() assert isinstance(returned, list) @@ -362,15 +369,22 @@ async def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_with_filter(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), limit=1, - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, ) returned = await vector_query.get() assert isinstance(returned, list) @@ -383,15 +397,86 @@ async def test_vector_search_collection_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") 
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +async def test_vector_search_collection_with_distance_parameters_euclid( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { 
+ "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_group(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = await vector_query.get() @@ -405,15 +490,24 @@ async def test_vector_search_collection_group(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_group_with_filter( + client, database, distance_measure +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = await vector_query.get() @@ -425,6 +519,70 @@ async def 
test_vector_search_collection_group_with_filter(client, database): } +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group_with_distance_parameters_euclid( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + 
"vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database): diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 340ccb30e..564ec32bc 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -108,6 +108,12 @@ def make_vector_query(*args, **kw): return VectorQuery(*args, **kw) +def make_async_vector_query(*args, **kw): + from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery + + return AsyncVectorQuery(*args, **kw) + + def build_test_timestamp( year: int = 2021, month: int = 1, diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py index 8b2a95a26..390190b53 100644 --- a/tests/unit/v1/test_async_vector_query.py +++ b/tests/unit/v1/test_async_vector_query.py @@ -18,7 +18,12 @@ from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.types.query import StructuredQuery from google.cloud.firestore_v1.vector import Vector -from tests.unit.v1._test_helpers import make_async_client, make_async_query, make_query +from tests.unit.v1._test_helpers import ( + make_async_client, + make_async_query, + make_async_vector_query, + make_query, +) from tests.unit.v1.test__helpers import AsyncIter, AsyncMock from tests.unit.v1.test_base_query import _make_query_response @@ -33,7 +38,15 @@ def _transaction(client): return transaction -def _expected_pb(parent, vector_field, vector, distance_type, limit): +def _expected_pb( + parent, + vector_field, + vector, + distance_type, + limit, + distance_result_field=None, + distance_threshold=None, +): query = make_query(parent) expected_pb = 
query._to_protobuf() expected_pb.find_nearest = StructuredQuery.FindNearest( @@ -41,10 +54,40 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit): query_vector=encode_value(vector.to_map_value()), distance_measure=distance_type, limit=limit, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) return expected_pb +def test_async_vector_query_int_threshold_constructor_to_pb(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + vector_query = make_async_vector_query(query) + + assert vector_query._nested_query == query + assert vector_query._client == query._parent._client + + vector_query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + distance_threshold=5, + ) + + expected_pb = query._to_protobuf() + expected_pb.find_nearest = StructuredQuery.FindNearest( + vector_field=StructuredQuery.FieldReference(field_path="embedding"), + query_vector=encode_value(Vector([1.0, 2.0, 3.0]).to_map_value()), + distance_measure=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + limit=5, + distance_threshold=5.0, + ) + assert vector_query._to_protobuf() == expected_pb + + @pytest.mark.parametrize( "distance_measure, expected_distance", [ @@ -188,6 +231,154 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc ) +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query_with_distance_result_field( + distance_measure, expected_distance +): + # Create a minimal fake GAPIC. 
+ firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5]), "vector_distance": 0.5} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + distance_result_field="vector_distance", + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + distance_result_field="vector_distance", + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def 
test_async_vector_query_with_distance_threshold( + distance_measure, expected_distance +): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + distance_threshold=125.5, + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + distance_threshold=125.5, + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + @pytest.mark.parametrize( "distance_measure, expected_distance", [ diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index beb094141..a5b1d342b 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -54,6 +54,8 @@ def 
def test_vector_query_int_threshold_constructor_to_pb():
    """Check that an integer ``distance_threshold`` matches a float in the pb.

    ``find_nearest`` is called with ``distance_threshold=5`` and the
    resulting protobuf is compared with one built using
    ``distance_threshold=5.0``.
    """
    client = make_client()
    collection = client.collection("dee")
    base_query = make_query(collection)
    vector_query = make_vector_query(base_query)

    # The vector query wraps the base query and shares its client.
    assert vector_query._nested_query == base_query
    assert vector_query._client == base_query._parent._client

    search_vector = Vector([1.0, 2.0, 3.0])
    vector_query.find_nearest(
        vector_field="embedding",
        query_vector=search_vector,
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=5,
        distance_threshold=5,
    )

    # Build the expected message by hand on top of the base query's pb.
    find_nearest_pb = StructuredQuery.FindNearest(
        vector_field=StructuredQuery.FieldReference(field_path="embedding"),
        query_vector=encode_value(Vector([1.0, 2.0, 3.0]).to_map_value()),
        distance_measure=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
        limit=5,
        distance_threshold=5.0,
    )
    expected_pb = base_query._to_protobuf()
    expected_pb.find_nearest = find_nearest_pb

    assert vector_query._to_protobuf() == expected_pb
@pytest.mark.parametrize(
    "distance_measure, expected_distance",
    [
        (
            DistanceMeasure.EUCLIDEAN,
            StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
        ),
        (
            DistanceMeasure.COSINE,
            StructuredQuery.FindNearest.DistanceMeasure.COSINE,
        ),
        (
            DistanceMeasure.DOT_PRODUCT,
            StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
        ),
    ],
)
def test_vector_query_with_distance_result_field(distance_measure, expected_distance):
    """Check that ``distance_result_field`` reaches the ``run_query`` request.

    A vector query is built from a collection with
    ``distance_result_field="vector_distance"`` and run against a mocked
    GAPIC. The test verifies the single returned snapshot and that the
    structured query sent to ``run_query`` matches the protobuf built by
    ``_expected_pb`` with the same result field.
    """
    # Stand-in for the GAPIC layer: only ``run_query`` is available.
    fake_api = mock.Mock(spec=["run_query"])
    client = make_client()
    client._firestore_api_internal = fake_api

    # The parent is a real collection reference, not a mock.
    collection = client.collection("dee")
    parent_path, expected_prefix = collection._parent_info()

    # The fake document already carries the "vector_distance" key.
    data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5]), "vector_distance": 0.5}
    doc_name = "{}/test_doc".format(expected_prefix)
    response_pb = _make_query_response(name=doc_name, data=data)

    kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

    # The mocked RPC yields exactly one response.
    fake_api.run_query.return_value = iter([response_pb])

    vector_query = collection.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=distance_measure,
        limit=5,
        distance_result_field="vector_distance",
    )

    snapshots = vector_query.get(transaction=_transaction(client), **kwargs)
    assert isinstance(snapshots, list)
    assert len(snapshots) == 1
    assert snapshots[0].to_dict() == data

    expected_request = {
        "parent": parent_path,
        "structured_query": _expected_pb(
            parent=collection,
            vector_field="embedding",
            vector=Vector([1.0, 2.0, 3.0]),
            distance_type=expected_distance,
            limit=5,
            distance_result_field="vector_distance",
        ),
        "transaction": _TXN_ID,
    }
    fake_api.run_query.assert_called_once_with(
        request=expected_request,
        metadata=client._rpc_metadata,
        **kwargs,
    )


@pytest.mark.parametrize(
    "distance_measure, expected_distance",
    [
        (
            DistanceMeasure.EUCLIDEAN,
            StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
        ),
        (
            DistanceMeasure.COSINE,
            StructuredQuery.FindNearest.DistanceMeasure.COSINE,
        ),
        (
            DistanceMeasure.DOT_PRODUCT,
            StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
        ),
    ],
)
def test_vector_query_with_distance_threshold(distance_measure, expected_distance):
    """Check that ``distance_threshold`` reaches the ``run_query`` request.

    A vector query is built from a collection with
    ``distance_threshold=0.75`` and run against a mocked GAPIC. The test
    verifies the single returned snapshot and that the structured query
    sent to ``run_query`` matches the protobuf built by ``_expected_pb``
    with the same threshold.
    """
    # Stand-in for the GAPIC layer: only ``run_query`` is available.
    fake_api = mock.Mock(spec=["run_query"])
    client = make_client()
    client._firestore_api_internal = fake_api

    # The parent is a real collection reference, not a mock.
    collection = client.collection("dee")
    parent_path, expected_prefix = collection._parent_info()

    data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5])}
    doc_name = "{}/test_doc".format(expected_prefix)
    response_pb = _make_query_response(name=doc_name, data=data)

    kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

    # The mocked RPC yields exactly one response.
    fake_api.run_query.return_value = iter([response_pb])

    vector_query = collection.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=distance_measure,
        limit=5,
        distance_threshold=0.75,
    )

    snapshots = vector_query.get(transaction=_transaction(client), **kwargs)
    assert isinstance(snapshots, list)
    assert len(snapshots) == 1
    assert snapshots[0].to_dict() == data

    expected_request = {
        "parent": parent_path,
        "structured_query": _expected_pb(
            parent=collection,
            vector_field="embedding",
            vector=Vector([1.0, 2.0, 3.0]),
            distance_type=expected_distance,
            limit=5,
            distance_threshold=0.75,
        ),
        "transaction": _TXN_ID,
    }
    fake_api.run_query.assert_called_once_with(
        request=expected_request,
        metadata=client._rpc_metadata,
        **kwargs,
    )