From 32f5cf14665dbf29a0a6dba54bc002b5bcf73079 Mon Sep 17 00:00:00 2001 From: Ethan Henderson Date: Tue, 19 Sep 2023 11:09:05 -0700 Subject: [PATCH] Log Scalars to TensorBoard Reviewed By: ananthsub Differential Revision: D49421324 fbshipit-source-id: 5a5b5f8f0a60219caa5b40e9451ef379a7346cd5 --- tests/utils/loggers/test_tensorboard.py | 31 ++++++++++++++++++++++++- torchtnt/utils/loggers/tensorboard.py | 26 ++++++++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/utils/loggers/test_tensorboard.py b/tests/utils/loggers/test_tensorboard.py index c23b0df550..73a34df101 100644 --- a/tests/utils/loggers/test_tensorboard.py +++ b/tests/utils/loggers/test_tensorboard.py @@ -6,10 +6,11 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations + import os import tempfile import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch import torch.distributed.launcher as launcher from tensorboard.backend.event_processing.event_accumulator import EventAccumulator @@ -94,3 +95,31 @@ def _test_distributed() -> None: def test_multiple_workers(self: TensorBoardLoggerTest) -> None: config = get_pet_launch_config(2) launcher.elastic_launch(config, entrypoint=self._test_distributed)() + + def test_add_scalars_call_is_correctly_passed_to_summary_writer( + self: TensorBoardLoggerTest, + ) -> None: + with patch( + "torchtnt.utils.loggers.tensorboard.SummaryWriter" + ) as mock_summary_writer_class: + mock_summary_writer = Mock() + mock_summary_writer_class.return_value = mock_summary_writer + logger = TensorBoardLogger(path="/tmp") + logger.log_scalars( + "tnt_metrics", + { + "x": 0, + "y": 1, + }, + 1, + 2, + ) + mock_summary_writer.add_scalars.assert_called_with( + main_tag="tnt_metrics", + tag_scalar_dict={ + "x": 0, + "y": 1, + }, + global_step=1, + walltime=2, + ) diff --git a/torchtnt/utils/loggers/tensorboard.py b/torchtnt/utils/loggers/tensorboard.py index b177ebf00e..ea8dea0edc 100644 --- a/torchtnt/utils/loggers/tensorboard.py +++ b/torchtnt/utils/loggers/tensorboard.py @@ -9,7 +9,7 @@ import atexit import logging -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Union import torch.distributed as dist @@ -175,6 +175,30 @@ def log_audio(self: TensorBoardLogger, *args: Any, **kwargs: Any) -> None: if writer: writer.add_audio(*args, **kwargs) + def log_scalars( + self: TensorBoardLogger, + main_tag: str, + tag_scalar_dict: Dict[str, Union[float, int]], + global_step: Optional[int] = None, + walltime: Optional[float] = None, + ) -> None: + """Log multiple values to TensorBoard. + Args: + main_tag (string): Parent name for the tags + tag_scalar_dict (dict): dictionary of tag name and scalar value + global_step (int): global step value to record + walltime (float): Optional override default walltime (time.time()) + Returns: + None + """ + if self._writer: + self._writer.add_scalars( + main_tag=main_tag, + tag_scalar_dict=tag_scalar_dict, + global_step=global_step, + walltime=walltime, + ) + def flush(self: TensorBoardLogger) -> None: """Writes pending logs to disk."""