Skip to content

Commit

Permalink
Log Scalars to TensorBoard (#544)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #544

Reviewed By: ananthsub

Differential Revision: D49421324

fbshipit-source-id: b4f155711365840b2b10693266199cfa511893f0
  • Loading branch information
Ethan Henderson authored and facebook-github-bot committed Sep 20, 2023
1 parent 1e9c78e commit 7f07d33
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
30 changes: 29 additions & 1 deletion tests/utils/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import tempfile
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

import torch.distributed.launcher as launcher
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
Expand Down Expand Up @@ -95,3 +95,31 @@ def _test_distributed() -> None:
def test_multiple_workers(self: TensorBoardLoggerTest) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(config, entrypoint=self._test_distributed)()

def test_add_scalars_call_is_correctly_passed_to_summary_writer(
self: TensorBoardLoggerTest,
) -> None:
with patch(
"torchtnt.utils.loggers.tensorboard.SummaryWriter"
) as mock_summary_writer_class:
mock_summary_writer = Mock()
mock_summary_writer_class.return_value = mock_summary_writer
logger = TensorBoardLogger(path="/tmp")
logger.log_scalars(
"tnt_metrics",
{
"x": 0,
"y": 1,
},
1,
2,
)
mock_summary_writer.add_scalars.assert_called_with(
main_tag="tnt_metrics",
tag_scalar_dict={
"x": 0,
"y": 1,
},
global_step=1,
walltime=2,
)
26 changes: 25 additions & 1 deletion torchtnt/utils/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import atexit
import logging
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, Union

import torch.distributed as dist

Expand Down Expand Up @@ -176,6 +176,30 @@ def log_audio(self: TensorBoardLogger, *args: Any, **kwargs: Any) -> None:
if writer:
writer.add_audio(*args, **kwargs)

def log_scalars(
self: TensorBoardLogger,
main_tag: str,
tag_scalar_dict: Dict[str, Union[float, int]],
global_step: Optional[int] = None,
walltime: Optional[float] = None,
) -> None:
"""Log multiple values to TensorBoard.
Args:
main_tag (string): Parent name for the tags
tag_scalar_dict (dict): dictionary of tag name and scalar value
global_step (int): global step value to record
walltime (float): Optional override default walltime (time.time())
Returns:
None
"""
if self._writer:
self._writer.add_scalars(
main_tag=main_tag,
tag_scalar_dict=tag_scalar_dict,
global_step=global_step,
walltime=walltime,
)

def flush(self: TensorBoardLogger) -> None:
"""Writes pending logs to disk."""

Expand Down

0 comments on commit 7f07d33

Please sign in to comment.