Skip to content

Commit

Permalink
feat: query profiling part 1: synchronous (#938)
Browse files Browse the repository at this point in the history
* feat: support query profiling

* collection

* fix unit tests

* unit tests

* vector get and stream, unit tests

* aggregation get and stream, unit tests

* docstring

* query profile unit tests

* update base classes' method signature

* documentsnapshotlist unit tests

* func signatures

* undo client.py change

* transaction.get()

* lint

* system test

* fix shim test

* fix sys test

* fix sys test

* system test

* another system test

* skip system test in emulator

* stream generator unit tests

* coverage

* add system tests

* small fixes

* undo document change

* add system tests

* vector query system tests

* format

* fix system test

* comments

* add system tests

* improve stream generator

* type checking

* adding stars

* delete comment

* remove coverage requirements for type checking part

* add explain_options to StreamGenerator

* yield tuple instead

* raise exception when explain_metrics is absent

* refactor documentsnapshotlist into queryresultslist

* add comment

* improve type hint

* lint

* move QueryResultsList to stream_generator.py

* aggregation related type annotation

* transaction return type hint

* refactor QueryResultsList

* change stream generator to return ExplainMetrics instead of yield

* update aggregation query to use the new generator

* update query to use the new generator

* update vector query to use the new generator

* lint

* type annotations

* fix type annotation to be python 3.9 compatible

* fix type hint for python 3.8

* fix system test

* add test coverage

* use class method get_explain_metrics() instead of property explain_metrics

* address comments

* remove more Optional

* add type hint for async stream generator

* simplify yield in aggregation stream

* stream generator type annotation

* more type hints

* remove "Integer"

* docstring format

* mypy

* add more input verification for query_results.py
  • Loading branch information
Linchin authored Sep 6, 2024
1 parent 3a546d3 commit 1614b3f
Show file tree
Hide file tree
Showing 31 changed files with 2,274 additions and 176 deletions.
2 changes: 2 additions & 0 deletions google/cloud/firestore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from google.cloud.firestore_v1 import DocumentSnapshot
from google.cloud.firestore_v1 import DocumentTransform
from google.cloud.firestore_v1 import ExistsOption
from google.cloud.firestore_v1 import ExplainOptions
from google.cloud.firestore_v1 import FieldFilter
from google.cloud.firestore_v1 import GeoPoint
from google.cloud.firestore_v1 import Increment
Expand Down Expand Up @@ -78,6 +79,7 @@
"DocumentSnapshot",
"DocumentTransform",
"ExistsOption",
"ExplainOptions",
"FieldFilter",
"GeoPoint",
"Increment",
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/firestore_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from google.cloud.firestore_v1.collection import CollectionReference
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.query import CollectionGroup, Query
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.transaction import Transaction, transactional
from google.cloud.firestore_v1.transforms import (
DELETE_FIELD,
Expand Down Expand Up @@ -131,6 +132,7 @@
"DocumentSnapshot",
"DocumentTransform",
"ExistsOption",
"ExplainOptions",
"FieldFilter",
"GeoPoint",
"Increment",
Expand Down
90 changes: 73 additions & 17 deletions google/cloud/firestore_v1/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
BaseAggregationQuery,
_query_response_to_result,
)
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.query_results import QueryResultsList
from google.cloud.firestore_v1.stream_generator import StreamGenerator

# Types needed only for Type Hints
if TYPE_CHECKING:
from google.cloud.firestore_v1 import transaction # pragma: NO COVER
if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1 import transaction
from google.cloud.firestore_v1.query_profile import ExplainMetrics
from google.cloud.firestore_v1.query_profile import ExplainOptions


class AggregationQuery(BaseAggregationQuery):
Expand All @@ -54,10 +56,14 @@ def get(
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> List[AggregationResult]:
*,
explain_options: Optional[ExplainOptions] = None,
) -> QueryResultsList[AggregationResult]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages.
This sends a ``RunAggregationQuery`` RPC and returns a list of
aggregation results in the stream of ``RunAggregationQueryResponse``
messages.
Args:
transaction
Expand All @@ -70,20 +76,39 @@ def get(
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
list: The aggregation query results
QueryResultsList[AggregationResult]: The aggregation query results.
"""
result = self.stream(transaction=transaction, retry=retry, timeout=timeout)
return list(result) # type: ignore
explain_metrics: ExplainMetrics | None = None

def _get_stream_iterator(self, transaction, retry, timeout):
result = self.stream(
transaction=transaction,
retry=retry,
timeout=timeout,
explain_options=explain_options,
)
result_list = list(result)

if explain_options is None:
explain_metrics = None
else:
explain_metrics = result.get_explain_metrics()

return QueryResultsList(result_list, explain_options, explain_metrics)

def _get_stream_iterator(self, transaction, retry, timeout, explain_options=None):
"""Helper method for :meth:`stream`."""
request, kwargs = self._prep_stream(
transaction,
retry,
timeout,
explain_options,
)

return self._client._firestore_api.run_aggregation_query(
Expand All @@ -106,9 +131,12 @@ def _retry_query_after_exception(self, exc, retry, transaction):
def _make_stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> Union[Generator[List[AggregationResult], Any, None]]:
explain_options: Optional[ExplainOptions] = None,
) -> Generator[List[AggregationResult], Any, Optional[ExplainMetrics]]:
"""Internal method for stream(). Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
Expand All @@ -127,16 +155,27 @@ def _make_stream(
system-specified policy.
timeout (Optional[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Yields:
:class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`:
List[AggregationResult]:
The result of aggregations of this query.
Returns:
(Optional[google.cloud.firestore_v1.types.query_profile.ExplainMetrtics]):
The results of query profiling, if received from the service.
"""
metrics: ExplainMetrics | None = None

response_iterator = self._get_stream_iterator(
transaction,
retry,
timeout,
explain_options,
)
while True:
try:
Expand All @@ -154,15 +193,26 @@ def _make_stream(

if response is None: # EOI
break

if metrics is None and response.explain_metrics:
metrics = response.explain_metrics

result = _query_response_to_result(response)
yield result
if result:
yield result

return metrics

def stream(
self,
transaction: Optional["transaction.Transaction"] = None,
retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> "StreamGenerator[DocumentSnapshot]":
*,
explain_options: Optional[ExplainOptions] = None,
) -> StreamGenerator[List[AggregationResult]]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and then returns a generator
Expand All @@ -181,13 +231,19 @@ def stream(
system-specified policy.
timeout (Optinal[float]): The timeout for this request. Defaults
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
`StreamGenerator[DocumentSnapshot]`: A generator of the query results.
`StreamGenerator[List[AggregationResult]]`:
A generator of the query results.
"""
inner_generator = self._make_stream(
transaction=transaction,
retry=retry,
timeout=timeout,
explain_options=explain_options,
)
return StreamGenerator(inner_generator)
return StreamGenerator(inner_generator, explain_options)
4 changes: 2 additions & 2 deletions google/cloud/firestore_v1/async_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def get(
retries.AsyncRetry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> List[AggregationResult]:
) -> List[List[AggregationResult]]:
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages.
Expand All @@ -71,7 +71,7 @@ async def get(
system-specified value.
Returns:
list: The aggregation query results
List[List[AggregationResult]]: The aggregation query results
"""
stream_result = self.stream(
Expand Down
19 changes: 11 additions & 8 deletions google/cloud/firestore_v1/async_stream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,28 @@
Firestore API.
"""

from collections import abc
from typing import Any, AsyncGenerator, Awaitable, TypeVar


class AsyncStreamGenerator(abc.AsyncGenerator):
T = TypeVar("T")


class AsyncStreamGenerator(AsyncGenerator[T, Any]):
"""Asynchronous generator for the streamed results."""

def __init__(self, response_generator):
def __init__(self, response_generator: AsyncGenerator[T, Any]):
self._generator = response_generator

def __aiter__(self):
return self._generator
def __aiter__(self) -> AsyncGenerator[T, Any]:
return self

def __anext__(self):
def __anext__(self) -> Awaitable[T]:
return self._generator.__anext__()

def asend(self, value=None):
def asend(self, value=None) -> Awaitable[Any]:
return self._generator.asend(value)

def athrow(self, exp=None):
def athrow(self, exp=None) -> Awaitable[Any]:
return self._generator.athrow(exp)

def aclose(self):
Expand Down
64 changes: 40 additions & 24 deletions google/cloud/firestore_v1/base_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,7 @@

import abc
from abc import ABC
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Coroutine,
Generator,
List,
Optional,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union

from google.api_core import gapic_v1
from google.api_core import retry as retries
Expand All @@ -47,8 +37,14 @@
)

# Types needed only for Type Hints
if TYPE_CHECKING:
from google.cloud.firestore_v1 import transaction # pragma: NO COVER
if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1 import transaction
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
from google.cloud.firestore_v1.query_profile import ExplainOptions
from google.cloud.firestore_v1.query_results import QueryResultsList
from google.cloud.firestore_v1.stream_generator import (
StreamGenerator,
)


class AggregationResult(object):
Expand All @@ -62,7 +58,7 @@ class AggregationResult(object):
:param value: The resulting read_time
"""

def __init__(self, alias: str, value: int, read_time=None):
def __init__(self, alias: str, value: float, read_time=None):
self.alias = alias
self.value = value
self.read_time = read_time
Expand Down Expand Up @@ -211,13 +207,16 @@ def _prep_stream(
transaction=None,
retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None,
timeout: float | None = None,
explain_options: Optional[ExplainOptions] = None,
) -> Tuple[dict, dict]:
parent_path, expected_prefix = self._collection_ref._parent_info()
request = {
"parent": parent_path,
"structured_aggregation_query": self._to_protobuf(),
"transaction": _helpers.get_transaction_id(transaction),
}
if explain_options:
request["explain_options"] = explain_options._to_dict()
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

return request, kwargs
Expand All @@ -230,10 +229,17 @@ def get(
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: float | None = None,
) -> List[AggregationResult] | Coroutine[Any, Any, List[AggregationResult]]:
*,
explain_options: Optional[ExplainOptions] = None,
) -> (
QueryResultsList[AggregationResult]
| Coroutine[Any, Any, List[List[AggregationResult]]]
):
"""Runs the aggregation query.
This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages.
This sends a ``RunAggregationQuery`` RPC and returns a list of
aggregation results in the stream of ``RunAggregationQueryResponse``
messages.
Args:
transaction
Expand All @@ -246,22 +252,27 @@ def get(
should be retried. Defaults to a system-specified policy.
timeout (float): The timeout for this request. Defaults to a
system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
list: The aggregation query results
(QueryResultsList[List[AggregationResult]] | Coroutine[Any, Any, List[List[AggregationResult]]]):
The aggregation query results.
"""

@abc.abstractmethod
def stream(
self,
transaction: Optional[transaction.Transaction] = None,
retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT,
retry: Union[
retries.Retry, None, gapic_v1.method._MethodDefault
] = gapic_v1.method.DEFAULT,
timeout: Optional[float] = None,
) -> (
Generator[List[AggregationResult], Any, None]
| AsyncGenerator[List[AggregationResult], None]
):
*,
explain_options: Optional[ExplainOptions] = None,
) -> StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator:
"""Runs the aggregation query.
This sends a``RunAggregationQuery`` RPC and returns a generator in the stream of ``RunAggregationQueryResponse`` messages.
Expand All @@ -274,8 +285,13 @@ def stream(
errors, if any, should be retried. Defaults to a
system-specified policy.
timeout (Optinal[float]): The timeout for this request. Defaults
to a system-specified value.
to a system-specified value.
explain_options
(Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
Options to enable query profiling for this query. When set,
explain_metrics will be available on the returned generator.
Returns:
StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator:
A generator of the query results.
"""
Loading

0 comments on commit 1614b3f

Please sign in to comment.