From 56f0f345298d27ec87a36ed5c37dea7fbd94493f Mon Sep 17 00:00:00 2001
From: zoy
Date: Wed, 24 Aug 2022 06:41:04 -0700
Subject: [PATCH] Combine a few beam metrics to reduce the number of counters in a pipeline.

PiperOrigin-RevId: 469712221
---
 tfx_bsl/telemetry/collection.py | 30 +++++++++++++++++++++++++-----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/tfx_bsl/telemetry/collection.py b/tfx_bsl/telemetry/collection.py
index 823b42d1..e77dc845 100644
--- a/tfx_bsl/telemetry/collection.py
+++ b/tfx_bsl/telemetry/collection.py
@@ -21,6 +21,8 @@
 from tensorflow_metadata.proto.v0 import schema_pb2
 
 
+# TODO(b/68154497): remove this. # pylint: disable=no-value-for-parameter
+
 
 def _IncrementCounter(element: int, counter_namespace: str,
                       counter_name: str) -> int:
@@ -30,19 +32,37 @@ def _IncrementCounter(element: int, counter_namespace: str,
 
 
 @beam.ptransform_fn
-def TrackRecordBatchBytes(dataset: beam.PCollection[pa.RecordBatch],
-                          counter_namespace: str,
-                          counter_name: str) -> beam.pvalue.PCollection[int]:
-  """Gathers telemetry on input record batch."""
+def ExtractRecordBatchBytes(
+    dataset: beam.PCollection[pa.RecordBatch]) -> beam.PCollection[int]:
+  """Extracts the total bytes of the input PCollection of RecordBatch."""
   return (dataset
           | "GetRecordBatchSize" >> beam.Map(lambda rb: rb.nbytes)
-          | "SumTotalBytes" >> beam.CombineGlobally(sum)
+          | "SumTotalBytes" >> beam.CombineGlobally(sum))
+
+
+@beam.ptransform_fn
+def IncrementCounter(value: beam.PCollection[int], counter_namespace: str,
+                     counter_name: str) -> beam.PCollection[int]:
+  """Increments the given counter after summing the input values."""
+  return (value
+          | "SumCounterValue" >> beam.CombineGlobally(sum)
           | "IncrementCounter" >> beam.Map(
               _IncrementCounter,
               counter_namespace=counter_namespace,
               counter_name=counter_name))
 
 
+@beam.ptransform_fn
+def TrackRecordBatchBytes(dataset: beam.PCollection[pa.RecordBatch],
+                          counter_namespace: str,
+                          counter_name: str) -> beam.PCollection[int]:
+  """Gathers telemetry on input record batch."""
+  return (dataset
+          | "GetRecordBatchSize" >> ExtractRecordBatchBytes()
+          | "IncrementCounter" >> IncrementCounter(
+              counter_namespace=counter_namespace, counter_name=counter_name))
+
+
 def _IncrementTensorRepresentationCounters(
     tensor_representations: Dict[str, schema_pb2.TensorRepresentation],
     counter_namespace: str):
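
Usage sketch (not taken from the patch itself): the snippet below shows how the reworked TrackRecordBatchBytes transform might be applied after this change. Because these functions are decorated with @beam.ptransform_fn, the input PCollection is bound by the pipe operator and only the counter arguments are passed when constructing the transform. The pipeline contents and the counter namespace/name used here ("telemetry_example", "record_batch_bytes") are illustrative assumptions.

import apache_beam as beam
import pyarrow as pa

from tfx_bsl.telemetry import collection

with beam.Pipeline() as p:
  record_batches = (
      p
      | "CreateRecordBatches" >> beam.Create(
          [pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ["f0"])]))
  # After this change, TrackRecordBatchBytes expands to ExtractRecordBatchBytes
  # followed by IncrementCounter, so the byte total is reported via one counter.
  _ = (
      record_batches
      | "TrackBytes" >> collection.TrackRecordBatchBytes(
          counter_namespace="telemetry_example",  # assumed namespace
          counter_name="record_batch_bytes"))  # assumed counter name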