Skip to content

Commit

Permalink
async collection
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchin committed Sep 3, 2024
1 parent 4c95727 commit 750ea98
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 12 deletions.
24 changes: 22 additions & 2 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Classes for representing collections for the Google Cloud Firestore API."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple

Expand All @@ -35,6 +36,8 @@
if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.query_results import QueryResultsList


class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
Expand Down Expand Up @@ -192,7 +195,9 @@ async def get(
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> list:
*,
explain_options: Optional[ExplainOptions] = None,
) -> QueryResultsList[DocumentSnapshot]:
"""Read the documents in this collection.
This sends a ``RunQuery`` RPC and returns a list of documents
Expand All @@ -207,14 +212,21 @@ async def get(
system-specified policy.
timeout (Otional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
If a ``transaction`` is used and it already has write operations added,
this method cannot be used (i.e. read-after-write is not allowed).
Returns:
list: The documents in this collection that match the query.
QueryResultsList[DocumentSnapshot]:
The documents in this collection that match the query.
"""
query, kwargs = self._prep_get_or_stream(retry, timeout)
if explain_options is not None:
kwargs["explain_options"] = explain_options

return await query.get(transaction=transaction, **kwargs)

Expand All @@ -223,6 +235,8 @@ def stream(
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
*,
explain_options: Optional[ExplainOptions] = None,
) -> "AsyncStreamGenerator[DocumentSnapshot]":
"""Read the documents in this collection.
Expand Down Expand Up @@ -250,11 +264,17 @@ def stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
`AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query
results.
"""
query, kwargs = self._prep_get_or_stream(retry, timeout)
if explain_options:
kwargs["explain_options"] = explain_options

return query.stream(transaction=transaction, **kwargs)
6 changes: 2 additions & 4 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,7 @@ def stream(
timeout: Optional[float] = None,
*,
explain_options: Optional[ExplainOptions] = None,
) -> AsyncGenerator[
[async_document.DocumentSnapshot | query_profile_pb.ExplainMetrics], Any
]:
) -> AsyncGenerator[async_document.DocumentSnapshot, Any]:
"""Read the documents in the collection that match this query.
This sends a ``RunQuery`` RPC and then returns a generator which
Expand Down Expand Up @@ -436,7 +434,7 @@ def stream(
explain_metrics will be available on the returned generator.
Returns:
`AsyncGenerator[[async_document.DocumentSnapshot | query_profile_pb.ExplainMetrics], Any]`:
`AsyncGenerator[async_document.DocumentSnapshot, Any]`:
An asynchronous generator of the queryresults.
"""
inner_generator = self._make_stream(
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/v1/test_async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,51 @@ async def test_asynccollectionreference_stream_with_transaction(query_class):
query_instance.stream.assert_called_once_with(transaction=transaction)


@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True)
@pytest.mark.asyncio
async def test_asynccollectionreference_stream_w_explain_options(query_class):
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.query_profile import (
ExplainMetrics,
ExplainOptions,
QueryExplainError,
)
import google.cloud.firestore_v1.types.query_profile as query_profile_pb2

explain_options = ExplainOptions(analyze=True)
explain_metrics = query_profile_pb2.ExplainMetrics(
{"execution_stats": {"results_returned": 1}}
)

async def response_generator():
for item in [1, 2, 3, explain_metrics]:
yield item

query_class.return_value.stream.return_value = AsyncStreamGenerator(
response_generator(), explain_options
)

collection = _make_async_collection_reference("collection")
stream_response = collection.stream(explain_options=ExplainOptions(analyze=True))
assert isinstance(stream_response, AsyncStreamGenerator)

with pytest.raises(QueryExplainError, match="explain_metrics not available"):
await stream_response.get_explain_metrics()

async for _ in stream_response:
pass

query_class.assert_called_once_with(collection)
query_instance = query_class.return_value
query_instance.stream.assert_called_once_with(
transaction=None, explain_options=explain_options
)

explain_metrics = await stream_response.get_explain_metrics()
assert isinstance(explain_metrics, ExplainMetrics)
assert explain_metrics.execution_stats.results_returned == 1


def test_asynccollectionreference_recursive():
from google.cloud.firestore_v1.async_query import AsyncQuery

Expand Down
12 changes: 6 additions & 6 deletions tests/unit/v1/test_async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,10 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None):
# Execute the query and check the response.
query = make_async_query(parent)

get_response = query.stream(**kwargs, explain_options=explain_options)
assert isinstance(get_response, AsyncStreamGenerator)
stream_response = query.stream(**kwargs, explain_options=explain_options)
assert isinstance(stream_response, AsyncStreamGenerator)

returned = [x async for x in get_response]
returned = [x async for x in stream_response]
assert len(returned) == 1
snapshot = returned[0]
assert snapshot.reference._path == ("dee", "sleep")
Expand All @@ -379,9 +379,9 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None):
# Verify explain_metrics.
if explain_options is None:
with pytest.raises(QueryExplainError, match="explain_options not set"):
await get_response.get_explain_metrics()
await stream_response.get_explain_metrics()
else:
explain_metrics = await get_response.get_explain_metrics()
explain_metrics = await stream_response.get_explain_metrics()
assert isinstance(explain_metrics, ExplainMetrics)
assert explain_metrics.execution_stats.results_returned == 1

Expand Down Expand Up @@ -685,7 +685,7 @@ async def test_asyncquery_stream_w_collection_group():


@pytest.mark.asyncio
async def test_query_stream_w_explain_options():
async def test_asyncquery_stream_w_explain_options():
from google.cloud.firestore_v1.query_profile import ExplainOptions

explain_options = ExplainOptions(analyze=True)
Expand Down

0 comments on commit 750ea98

Please sign in to comment.