Skip to content

Commit

Permalink
update vector query to use the new generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchin committed Aug 13, 2024
1 parent ccbb623 commit 843dc05
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 83 deletions.
18 changes: 14 additions & 4 deletions google/cloud/firestore_v1/vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def get(
Returns:
QueryResultsList[DocumentSnapshot]: The vector query results.
"""
explain_metrics: ExplainMetrics | None = None

result = self.stream(
transaction=transaction,
retry=retry,
Expand Down Expand Up @@ -149,9 +151,15 @@ def _make_stream(
explain_metrics will be available on the returned generator.
Yields:
Tuple[Optional[DocumentSnapshot], Optional[ExplainMetrics]]:
Optional[DocumentSnapshot]:
The next document that fulfills the query.
Returns:
([google.cloud.firestore_v1.types.query_profile.ExplainMetrics | None]):
The results of query profiling, if received from the service.
"""
metrics: ExplainMetrics | None = None

response_iterator, expected_prefix = self._get_stream_iterator(
transaction,
retry,
Expand All @@ -165,8 +173,8 @@ def _make_stream(
if response is None: # EOI
break

if response.explain_metrics:
yield None, response.explain_metrics
if metrics is None and response.explain_metrics:
metrics = response.explain_metrics

if self._nested_query._all_descendants:
snapshot = _collection_group_query_response_to_snapshot(
Expand All @@ -177,7 +185,9 @@ def _make_stream(
response, self._nested_query._parent, expected_prefix
)
if snapshot is not None:
yield snapshot, None
yield snapshot

return metrics

def stream(
self,
Expand Down
188 changes: 109 additions & 79 deletions tests/unit/v1/test_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.query_profile import ExplainOptions, QueryExplainError
from google.cloud.firestore_v1.query_profile import (
ExplainMetrics,
ExplainOptions,
QueryExplainError,
)
from google.cloud.firestore_v1.query_results import QueryResultsList
from google.cloud.firestore_v1.types.query import StructuredQuery
from google.cloud.firestore_v1.vector import Vector
from tests.unit.v1._test_helpers import make_client, make_query, make_vector_query
Expand Down Expand Up @@ -105,21 +110,7 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit):
return expected_pb


@pytest.mark.parametrize(
"distance_measure, expected_distance",
[
(
DistanceMeasure.EUCLIDEAN,
StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
),
(DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE),
(
DistanceMeasure.DOT_PRODUCT,
StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
),
],
)
def test_vector_query(distance_measure, expected_distance):
def _vector_query_get_helper(distance_measure, expected_distance, explain_options=None):
# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["run_query"])
client = make_client()
Expand All @@ -130,8 +121,14 @@ def test_vector_query(distance_measure, expected_distance):
parent_path, expected_prefix = parent._parent_info()

data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
if explain_options is not None:
explain_metrics = {"execution_stats": {"results_returned": 1}}
else:
explain_metrics = None
response_pb = _make_query_response(
name="{}/test_doc".format(expected_prefix), data=data
name="{}/test_doc".format(expected_prefix),
data=data,
explain_metrics=explain_metrics,
)

kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)
Expand All @@ -146,15 +143,19 @@ def test_vector_query(distance_measure, expected_distance):
limit=5,
)

returned = vector_query.get(transaction=_transaction(client), **kwargs)
assert isinstance(returned, list)
returned = vector_query.get(
transaction=_transaction(client), **kwargs, explain_options=explain_options
)
assert isinstance(returned, QueryResultsList)
assert len(returned) == 1
assert returned[0].to_dict() == data
with pytest.raises(
QueryExplainError,
match="explain_options not set on query",
):
returned.explain_metrics

if explain_options is None:
with pytest.raises(QueryExplainError, match="explain_options not set"):
returned.explain_metrics
else:
assert isinstance(returned.explain_metrics, ExplainMetrics)
assert returned.explain_metrics.execution_stats.results_returned == 1

expected_pb = _expected_pb(
parent=parent,
Expand All @@ -163,73 +164,102 @@ def test_vector_query(distance_measure, expected_distance):
distance_type=expected_distance,
limit=5,
)
expected_request = {
"parent": parent_path,
"structured_query": expected_pb,
"transaction": _TXN_ID,
}
if explain_options is not None:
expected_request["explain_options"] = explain_options._to_dict()
firestore_api.run_query.assert_called_once_with(
request={
"parent": parent_path,
"structured_query": expected_pb,
"transaction": _TXN_ID,
},
request=expected_request,
metadata=client._rpc_metadata,
**kwargs,
)


def test_vector_query_w_explain_options():
# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["run_query"])
client = make_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dee")
parent_path, expected_prefix = parent._parent_info()

data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
response_pb = _make_query_response(
name="{}/test_doc".format(expected_prefix),
data=data,
explain_metrics={"execution_stats": {"results_returned": 1}},
@pytest.mark.parametrize(
"distance_measure, expected_distance",
[
(
DistanceMeasure.EUCLIDEAN,
StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
),
(DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE),
(
DistanceMeasure.DOT_PRODUCT,
StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
),
],
)
def test_vector_query(distance_measure, expected_distance):
_vector_query_get_helper(
distance_measure=distance_measure, expected_distance=expected_distance
)

kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

# Execute the vector query and check the response.
firestore_api.run_query.return_value = iter([response_pb])
vector_query = parent.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=5,
)

def test_vector_query_w_explain_options():
explain_options = ExplainOptions(analyze=True)
returned = vector_query.get(
transaction=_transaction(client),
**kwargs,
_vector_query_get_helper(
distance_measure=DistanceMeasure.EUCLIDEAN,
expected_distance=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
explain_options=explain_options,
)
assert isinstance(returned, list)
assert len(returned) == 1
assert returned[0].to_dict() == data
assert returned.explain_metrics is not None

expected_pb = _expected_pb(
parent=parent,
vector_field="embedding",
vector=Vector([1.0, 2.0, 3.0]),
distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
limit=5,
)
firestore_api.run_query.assert_called_once_with(
request={
"parent": parent_path,
"structured_query": expected_pb,
"transaction": _TXN_ID,
"explain_options": explain_options._to_dict(),
},
metadata=client._rpc_metadata,
**kwargs,
)
# # Create a minimal fake GAPIC.
# firestore_api = mock.Mock(spec=["run_query"])
# client = make_client()
# client._firestore_api_internal = firestore_api

# # Make a **real** collection reference as parent.
# parent = client.collection("dee")
# parent_path, expected_prefix = parent._parent_info()

# data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
# response_pb = _make_query_response(
# name="{}/test_doc".format(expected_prefix),
# data=data,
# explain_metrics={"execution_stats": {"results_returned": 1}},
# )

# kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

# # Execute the vector query and check the response.
# firestore_api.run_query.return_value = iter([response_pb])
# vector_query = parent.find_nearest(
# vector_field="embedding",
# query_vector=Vector([1.0, 2.0, 3.0]),
# distance_measure=DistanceMeasure.EUCLIDEAN,
# limit=5,
# )

# explain_options = ExplainOptions(analyze=True)
# returned = vector_query.get(
# transaction=_transaction(client),
# **kwargs,
# explain_options=explain_options,
# )
# assert isinstance(returned, QueryResultsList)
# assert len(returned) == 1
# assert returned[0].to_dict() == data
# assert returned.explain_metrics is not None

# expected_pb = _expected_pb(
# parent=parent,
# vector_field="embedding",
# vector=Vector([1.0, 2.0, 3.0]),
# distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
# limit=5,
# )
# firestore_api.run_query.assert_called_once_with(
# request={
# "parent": parent_path,
# "structured_query": expected_pb,
# "transaction": _TXN_ID,
# "explain_options": explain_options._to_dict(),
# },
# metadata=client._rpc_metadata,
# **kwargs,
# )


@pytest.mark.parametrize(
Expand Down

0 comments on commit 843dc05

Please sign in to comment.