Skip to content

Commit

Permalink
Add Self Types to TensorBoardTests
Browse files Browse the repository at this point in the history
Differential Revision: D49423988

fbshipit-source-id: 701a2572384796ae69344e88ec970867a55a37e9
  • Loading branch information
Ethan Henderson authored and facebook-github-bot committed Sep 20, 2023
1 parent 1f26cf7 commit 664619e
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/utils/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import os
import tempfile
import unittest
Expand All @@ -19,7 +21,7 @@


class TensorBoardLoggerTest(unittest.TestCase):
def test_log(self) -> None:
def test_log(self: TensorBoardLoggerTest) -> None:
with tempfile.TemporaryDirectory() as log_dir:
logger = TensorBoardLogger(path=log_dir)
for i in range(5):
Expand All @@ -32,7 +34,7 @@ def test_log(self) -> None:
self.assertAlmostEqual(event.tensor_proto.float_val[0], float(i) ** 2)
self.assertEqual(event.step, i)

def test_log_dict(self) -> None:
def test_log_dict(self: TensorBoardLoggerTest) -> None:
with tempfile.TemporaryDirectory() as log_dir:
logger = TensorBoardLogger(path=log_dir)
metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)}
Expand All @@ -48,7 +50,7 @@ def test_log_dict(self) -> None:
)
self.assertEqual(tensor_tag.step, 1)

def test_log_text(self) -> None:
def test_log_text(self: TensorBoardLoggerTest) -> None:
with tempfile.TemporaryDirectory() as log_dir:
logger = TensorBoardLogger(path=log_dir)
for i in range(5):
Expand All @@ -64,7 +66,7 @@ def test_log_text(self) -> None:
)
self.assertEqual(test_text_event.step, i)

def test_log_rank_zero(self) -> None:
def test_log_rank_zero(self: TensorBoardLoggerTest) -> None:
with tempfile.TemporaryDirectory() as log_dir:
with patch.dict("os.environ", {"RANK": "1"}):
logger = TensorBoardLogger(path=log_dir)
Expand All @@ -90,6 +92,6 @@ def _test_distributed() -> None:
@unittest.skipUnless(
    dist.is_available(), reason="Torch distributed is needed to run"
)
def test_multiple_workers(self: TensorBoardLoggerTest) -> None:
    """Spawn two elastic workers and run the distributed logging check in each."""
    # Two-process elastic launch; every worker executes _test_distributed.
    launch_config = get_pet_launch_config(2)
    elastic = launcher.elastic_launch(launch_config, entrypoint=self._test_distributed)
    elastic()

0 comments on commit 664619e

Please sign in to comment.