Skip to content

Commit

Permalink
Merge pull request #58 from venkat2469/r1.10.0
Browse files Browse the repository at this point in the history
Combine a few beam metrics to reduce the number of counters in a pipe…
  • Loading branch information
rtg0795 authored Aug 26, 2022
2 parents 319de43 + 56f0f34 commit 1909d48
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions tfx_bsl/telemetry/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from tensorflow_metadata.proto.v0 import schema_pb2

# TODO(b/68154497): remove this. # pylint: disable=no-value-for-parameter


def _IncrementCounter(element: int, counter_namespace: str,
counter_name: str) -> int:
Expand All @@ -30,19 +32,37 @@ def _IncrementCounter(element: int, counter_namespace: str,


@beam.ptransform_fn
def TrackRecordBatchBytes(dataset: beam.PCollection[pa.RecordBatch],
counter_namespace: str,
counter_name: str) -> beam.pvalue.PCollection[int]:
"""Gathers telemetry on input record batch."""
def ExtractRecordBatchBytes(
dataset: beam.PCollection[pa.RecordBatch]) -> beam.PCollection[int]:
"""Extracts the total bytes of the input PCollection of RecordBatch."""
return (dataset
| "GetRecordBatchSize" >> beam.Map(lambda rb: rb.nbytes)
| "SumTotalBytes" >> beam.CombineGlobally(sum)
| "SumTotalBytes" >> beam.CombineGlobally(sum))


@beam.ptransform_fn
def IncrementCounter(value: beam.PCollection[int], counter_namespace: str,
counter_name: str) -> beam.PCollection[int]:
"""Increments the given counter after summing the input values."""
return (value
| "SumCounterValue" >> beam.CombineGlobally(sum)
| "IncrementCounter" >> beam.Map(
_IncrementCounter,
counter_namespace=counter_namespace,
counter_name=counter_name))


@beam.ptransform_fn
def TrackRecordBatchBytes(dataset: beam.PCollection[pa.RecordBatch],
counter_namespace: str,
counter_name: str) -> beam.PCollection[int]:
"""Gathers telemetry on input record batch."""
return (dataset
| "GetRecordBatchSize" >> ExtractRecordBatchBytes()
| "IncrementCounter" >> IncrementCounter(
counter_namespace=counter_namespace, counter_name=counter_name))


def _IncrementTensorRepresentationCounters(
tensor_representations: Dict[str, schema_pb2.TensorRepresentation],
counter_namespace: str):
Expand Down

0 comments on commit 1909d48

Please sign in to comment.