From 843dc05177e2cacd6792eaa935f4c259844ea3fc Mon Sep 17 00:00:00 2001 From: Linchin Date: Mon, 12 Aug 2024 23:26:21 -0700 Subject: [PATCH] update vector query to use the new generator --- google/cloud/firestore_v1/vector_query.py | 18 ++- tests/unit/v1/test_vector_query.py | 188 +++++++++++++--------- 2 files changed, 123 insertions(+), 83 deletions(-) diff --git a/google/cloud/firestore_v1/vector_query.py b/google/cloud/firestore_v1/vector_query.py index d57009fd0..5d978c3d8 100644 --- a/google/cloud/firestore_v1/vector_query.py +++ b/google/cloud/firestore_v1/vector_query.py @@ -85,6 +85,8 @@ def get( Returns: QueryResultsList[DocumentSnapshot]: The vector query results. """ + explain_metrics: ExplainMetrics | None = None + result = self.stream( transaction=transaction, retry=retry, @@ -149,9 +151,15 @@ def _make_stream( explain_metrics will be available on the returned generator. Yields: - Tuple[Optional[DocumentSnapshot], Optional[ExplainMetrics]]: + Optional[DocumentSnapshot]: The next document that fulfills the query. + + Returns: + ([google.cloud.firestore_v1.types.query_profile.ExplainMetrics | None]): + The results of query profiling, if received from the service. 
""" + metrics: ExplainMetrics | None = None + response_iterator, expected_prefix = self._get_stream_iterator( transaction, retry, @@ -165,8 +173,8 @@ def _make_stream( if response is None: # EOI break - if response.explain_metrics: - yield None, response.explain_metrics + if metrics is None and response.explain_metrics: + metrics = response.explain_metrics if self._nested_query._all_descendants: snapshot = _collection_group_query_response_to_snapshot( @@ -177,7 +185,9 @@ def _make_stream( response, self._nested_query._parent, expected_prefix ) if snapshot is not None: - yield snapshot, None + yield snapshot + + return metrics def stream( self, diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index 5131ee5d3..1fb06b714 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -17,7 +17,12 @@ from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.query_profile import ExplainOptions, QueryExplainError +from google.cloud.firestore_v1.query_profile import ( + ExplainMetrics, + ExplainOptions, + QueryExplainError, +) +from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.types.query import StructuredQuery from google.cloud.firestore_v1.vector import Vector from tests.unit.v1._test_helpers import make_client, make_query, make_vector_query @@ -105,21 +110,7 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit): return expected_pb -@pytest.mark.parametrize( - "distance_measure, expected_distance", - [ - ( - DistanceMeasure.EUCLIDEAN, - StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, - ), - (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), - ( - DistanceMeasure.DOT_PRODUCT, - StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, - ), - ], -) -def 
test_vector_query(distance_measure, expected_distance): +def _vector_query_get_helper(distance_measure, expected_distance, explain_options=None): # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) client = make_client() @@ -130,8 +121,14 @@ def test_vector_query(distance_measure, expected_distance): parent_path, expected_prefix = parent._parent_info() data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + if explain_options is not None: + explain_metrics = {"execution_stats": {"results_returned": 1}} + else: + explain_metrics = None response_pb = _make_query_response( - name="{}/test_doc".format(expected_prefix), data=data + name="{}/test_doc".format(expected_prefix), + data=data, + explain_metrics=explain_metrics, ) kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) @@ -146,15 +143,19 @@ def test_vector_query(distance_measure, expected_distance): limit=5, ) - returned = vector_query.get(transaction=_transaction(client), **kwargs) - assert isinstance(returned, list) + returned = vector_query.get( + transaction=_transaction(client), **kwargs, explain_options=explain_options + ) + assert isinstance(returned, QueryResultsList) assert len(returned) == 1 assert returned[0].to_dict() == data - with pytest.raises( - QueryExplainError, - match="explain_options not set on query", - ): - returned.explain_metrics + + if explain_options is None: + with pytest.raises(QueryExplainError, match="explain_options not set"): + returned.explain_metrics + else: + assert isinstance(returned.explain_metrics, ExplainMetrics) + assert returned.explain_metrics.execution_stats.results_returned == 1 expected_pb = _expected_pb( parent=parent, @@ -163,73 +164,102 @@ def test_vector_query(distance_measure, expected_distance): distance_type=expected_distance, limit=5, ) + expected_request = { + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + } + if explain_options is not None: + expected_request["explain_options"] = 
explain_options._to_dict() firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": expected_pb, - "transaction": _TXN_ID, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) -def test_vector_query_w_explain_options(): - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) - client = make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dee") - parent_path, expected_prefix = parent._parent_info() - - data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} - response_pb = _make_query_response( - name="{}/test_doc".format(expected_prefix), - data=data, - explain_metrics={"execution_stats": {"results_returned": 1}}, +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +def test_vector_query(distance_measure, expected_distance): + _vector_query_get_helper( + distance_measure=distance_measure, expected_distance=expected_distance ) - kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) - - # Execute the vector query and check the response. 
- firestore_api.run_query.return_value = iter([response_pb]) - vector_query = parent.find_nearest( - vector_field="embedding", - query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, - limit=5, - ) +def test_vector_query_w_explain_options(): explain_options = ExplainOptions(analyze=True) - returned = vector_query.get( - transaction=_transaction(client), - **kwargs, + _vector_query_get_helper( + distance_measure=DistanceMeasure.EUCLIDEAN, + expected_distance=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, explain_options=explain_options, ) - assert isinstance(returned, list) - assert len(returned) == 1 - assert returned[0].to_dict() == data - assert returned.explain_metrics is not None - - expected_pb = _expected_pb( - parent=parent, - vector_field="embedding", - vector=Vector([1.0, 2.0, 3.0]), - distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, - limit=5, - ) - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": expected_pb, - "transaction": _TXN_ID, - "explain_options": explain_options._to_dict(), - }, - metadata=client._rpc_metadata, - **kwargs, - ) + # # Create a minimal fake GAPIC. + # firestore_api = mock.Mock(spec=["run_query"]) + # client = make_client() + # client._firestore_api_internal = firestore_api + + # # Make a **real** collection reference as parent. + # parent = client.collection("dee") + # parent_path, expected_prefix = parent._parent_info() + + # data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + # response_pb = _make_query_response( + # name="{}/test_doc".format(expected_prefix), + # data=data, + # explain_metrics={"execution_stats": {"results_returned": 1}}, + # ) + + # kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # # Execute the vector query and check the response. 
+ # firestore_api.run_query.return_value = iter([response_pb]) + # vector_query = parent.find_nearest( + # vector_field="embedding", + # query_vector=Vector([1.0, 2.0, 3.0]), + # distance_measure=DistanceMeasure.EUCLIDEAN, + # limit=5, + # ) + + # explain_options = ExplainOptions(analyze=True) + # returned = vector_query.get( + # transaction=_transaction(client), + # **kwargs, + # explain_options=explain_options, + # ) + # assert isinstance(returned, QueryResultsList) + # assert len(returned) == 1 + # assert returned[0].to_dict() == data + # assert returned.explain_metrics is not None + + # expected_pb = _expected_pb( + # parent=parent, + # vector_field="embedding", + # vector=Vector([1.0, 2.0, 3.0]), + # distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + # limit=5, + # ) + # firestore_api.run_query.assert_called_once_with( + # request={ + # "parent": parent_path, + # "structured_query": expected_pb, + # "transaction": _TXN_ID, + # "explain_options": explain_options._to_dict(), + # }, + # metadata=client._rpc_metadata, + # **kwargs, + # ) @pytest.mark.parametrize(