Add always-on iteration timer to framework (#439)
Summary:
Pull Request resolved: #439

Add two counters, containing the iteration time and the time blocked on data for the last iteration. These are stored and recorded using a timer that does not synchronise with CUDA and that only records these two values.

- iteration time is recorded in the training loop for all jobs.
- data time is recorded if the training loop does the data fetching; otherwise the user needs to instrument the logic that reads data from the iterable (see the sketch at the end of this summary).

Had to rework the approach slightly, since the state does not seem like the right place. With the changes, it is more natural to make it similar to progress and store the info in a Stateful. This has the additional advantage that all values can be stored and restored with the checkpoints.

Also, rather than only the last value, the last LOWER_BOUND values are stored. I hardcoded this in the torchtnt state to avoid bloating the parameters, since intuitively the last 1e4 values for each timer action should be enough for monitoring purposes.
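
For the second case, here is a minimal, hypothetical sketch of what that instrumentation could look like in a user-defined unit (the unit and module names are made up; only the `iteration_timer` accessor and the `data_wait_time` key come from this diff):

    from typing import Iterator, Tuple

    import torch
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyIteratorTrainUnit(TrainUnit[Iterator[Batch]]):
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module

        def train_step(self, state: State, data: Iterator[Batch]) -> None:
            train_state = state.train_state
            assert train_state is not None  # populated while train() is running
            # record the time spent blocked on the dataloader
            with train_state.iteration_timer.time("data_wait_time"):
                inputs, targets = next(data)
            outputs = self.module(inputs)
            # loss computation, backward pass and optimizer step omitted for brevity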

Reviewed By: daniellepintz, ananthsub

Differential Revision: D46853794

fbshipit-source-id: ac2e9f992f21ac66625f89d1c75d228397740e1f
Miquel Jubert Hermoso authored and facebook-github-bot committed Aug 23, 2023
1 parent a690136 commit 1fab468
Showing 6 changed files with 131 additions and 13 deletions.
48 changes: 44 additions & 4 deletions tests/framework/test_train.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Iterator, Tuple
from typing import Any, Iterator, Mapping, Tuple
from unittest.mock import MagicMock

import torch
@@ -16,7 +16,7 @@
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TrainUnit
from torchtnt.framework.unit import TrainUnit, TTrainUnit
from torchtnt.utils.timer import Timer


@@ -32,10 +32,13 @@ def test_train(self) -> None:
expected_steps_per_epoch = dataset_len / batch_size

my_unit = DummyTrainUnit(input_dim=input_dim)
initial_training_mode = my_unit.module.training

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(my_unit, dataloader, max_epochs=max_epochs)
train(
my_unit,
dataloader,
max_epochs=max_epochs,
)

self.assertEqual(my_unit.train_progress.num_epochs_completed, max_epochs)
self.assertEqual(my_unit.train_progress.num_steps_completed_in_epoch, 0)
@@ -138,6 +141,43 @@ def test_train_with_callback(self) -> None:
self.assertEqual(callback_mock.on_train_epoch_end.call_count, max_epochs)
self.assertEqual(callback_mock.on_train_end.call_count, 1)

def test_train_uses_iteration_timer(self) -> None:
"""
Test train records time in the iteration_timer
"""
input_dim = 2
dataset_len = 10
batch_size = 2
max_steps_per_epoch = 1
max_epochs = 1

my_unit = DummyTrainUnit(input_dim=input_dim)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

def assertInTest(key: str, mapping: Mapping[str, Any]) -> None:
self.assertIn(key, mapping)

class CheckTimerUsedCallback(Callback):
def on_train_end(self, state: State, unit: TTrainUnit) -> None:
assertInTest(
"data_wait_time",
state.train_state.iteration_timer.recorded_durations,
)
assertInTest(
"train_iteration_time",
state.train_state.iteration_timer.recorded_durations,
)

check_timer_callback = CheckTimerUsedCallback()

train(
my_unit,
dataloader,
max_epochs=max_epochs,
max_steps_per_epoch=max_steps_per_epoch,
callbacks=[check_timer_callback],
)

def test_train_data_iter_step(self) -> None:
class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
def __init__(self, input_dim: int) -> None:
19 changes: 19 additions & 0 deletions tests/utils/test_timer.py
@@ -17,6 +17,7 @@
import torch.distributed.launcher as launcher
from torchtnt.utils.test_utils import get_pet_launch_config
from torchtnt.utils.timer import (
BoundedTimer,
FullSyncPeriodicTimer,
get_durations_histogram,
get_synced_durations_histogram,
@@ -44,6 +45,24 @@ def test_timer_verbose(self) -> None:
mock_info.assert_called_once()
self.assertTrue("Testing timer took" in mock_info.call_args.args[0])

def test_timer_context_manager_size_bound(self) -> None:
"""Test that timer keeps the number of samples within bounds"""
TEST_ACTION_STRING: str = "test action"
UPPER_BOUND: int = 10
LOWER_BOUND: int = 5
timer = BoundedTimer(lower_bound=LOWER_BOUND, upper_bound=UPPER_BOUND)
for i in range(1000):
with timer.time(TEST_ACTION_STRING):
pass
if i > LOWER_BOUND:
self.assertGreaterEqual(
len(timer.recorded_durations[TEST_ACTION_STRING]), LOWER_BOUND
)
self.assertLessEqual(
len(timer.recorded_durations[TEST_ACTION_STRING]),
UPPER_BOUND,
)

def test_timer_context_manager(self) -> None:
"""Test the context manager in the timer class"""

5 changes: 4 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -630,7 +630,10 @@ def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
def train_step(
self, state: State, data: Iterator[TData]
) -> Tuple[torch.Tensor, Any]:
batch = self._get_next_batch(state, data)
# In AutoUnit, data_wait_time and train_iteration_time will not be mutually
# exclusive, since data fetching is done as part of the training step
with none_throws(state.train_state).iteration_timer.time("data_wait_time"):
batch = self._get_next_batch(state, data)

should_update_weights = (
self.train_progress.num_steps_completed_in_epoch + 1
12 changes: 11 additions & 1 deletion torchtnt/framework/state.py
@@ -13,7 +13,7 @@
from enum import auto, Enum
from typing import Any, Iterable, Optional

from torchtnt.utils.timer import TimerProtocol
from torchtnt.utils.timer import BoundedTimer, TimerProtocol

_logger: logging.Logger = logging.getLogger(__name__)

@@ -88,6 +88,9 @@ def __init__(
self._evaluate_every_n_epochs = evaluate_every_n_epochs

self._step_output: Any = None
self._iteration_timer = BoundedTimer(
cuda_sync=False, lower_bound=1_000, upper_bound=5_000
)

@property
def dataloader(self) -> Iterable[Any]:
@@ -124,6 +127,13 @@ def step_output(self) -> Any:
"""Output of the last step."""
return self._step_output

@property
def iteration_timer(self) -> TimerProtocol:
"""An always-on :class:`~torchtnt.utils.TimerProtocol` object which contains CPU timings (without synchronisation) of the iterations. For now
only populated during training.
"""
return self._iteration_timer


class State:
"""Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases.
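
To illustrate how the always-on timer can be consumed, here is a hypothetical monitoring callback (not part of this commit; the class name is made up) that reads the recorded durations from the train PhaseState, the same way the new test does:

    from typing import List

    from torchtnt.framework.callback import Callback
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TTrainUnit

    class LogIterationTimes(Callback):
        def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
            train_state = state.train_state
            if train_state is None:
                return
            durations = train_state.iteration_timer.recorded_durations
            for action in ("train_iteration_time", "data_wait_time"):
                samples: List[float] = durations.get(action, [])
                if samples:
                    mean = sum(samples) / len(samples)
                    print(f"{action}: mean {mean:.4f}s over {len(samples)} steps")
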
18 changes: 11 additions & 7 deletions torchtnt/framework/train.py
@@ -198,16 +198,20 @@ def _train_epoch_impl(
try:
if not pass_data_iter_to_step:
# get the next batch from the data iterator
with get_timing_context(state, "train.next(data_iter)"):
# If the iterator is passed to step, step is responsible for recording data_wait_time
with get_timing_context(
state, "train.next(data_iter)"
), train_state.iteration_timer.time("data_wait_time"):
step_input = next(data_iter)

callback_handler.on_train_step_start(state, train_unit)
train_state._step_output = train_unit.train_step(state, step_input)
train_unit.train_progress.increment_step()
callback_handler.on_train_step_end(state, train_unit)
with train_state.iteration_timer.time("train_iteration_time"):
callback_handler.on_train_step_start(state, train_unit)
train_state._step_output = train_unit.train_step(state, step_input)
train_unit.train_progress.increment_step()
callback_handler.on_train_step_end(state, train_unit)

# clear step_output to avoid retaining extra memory
train_state._step_output = None
# clear step_output to avoid retaining extra memory
train_state._step_output = None

if (
evaluate_every_n_steps
42 changes: 42 additions & 0 deletions torchtnt/utils/timer.py
@@ -12,6 +12,7 @@
from contextlib import contextmanager
from time import perf_counter
from typing import (
Any,
Dict,
Generator,
List,
@@ -107,6 +108,9 @@ def __init__(
Args:
cuda_sync: whether to call torch.cuda.synchronize() before and after timing. Defaults to True if CUDA is available.
verbose: whether to enable verbose logging.
size_bounds: defines the range of samples that should be kept in the timer. The lower bound should be smaller than
the upper bound. When the number of samples for an action reaches the upper bound, the oldest (upper_bound - lower_bound)
samples are removed. This range is applied per action.
Note:
Enabling cuda_sync will incur a performance hit, but will ensure accurate timings on GPUs.
@@ -156,6 +160,44 @@ def reset(self) -> None:
self.recorded_durations = defaultdict(list)


class BoundedTimer(Timer):
"""
A Timer class which implements TimerProtocol and stores timings in a dictionary `recorded_durations`.
Same behavior as Timer, but with per-action size bounds given by (lower_bound, upper_bound).
Args:
...
lower_bound: the number of most recent samples kept per action after trimming. Must be positive and smaller than upper_bound.
upper_bound: the maximum number of samples kept per action. When the number of samples for an action reaches the upper bound, the oldest (upper_bound - lower_bound) samples are removed.
"""

def __init__(self, lower_bound: int, upper_bound: int, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert lower_bound > 0
assert lower_bound < upper_bound
self.lower_bound = lower_bound
self.upper_bound = upper_bound

@contextmanager
def time(
self,
action_name: str,
) -> Generator[None, None, None]:
with super().time(action_name):
yield
self._apply_bounds(action_name)

def _apply_bounds(self, action_name: str) -> None:
# Keep 'lower_bound' most recent samples, if at or over upper bound
n_samples: int = len(self.recorded_durations[action_name])
if self.upper_bound <= n_samples:
self.recorded_durations[action_name] = list(
self.recorded_durations[action_name][-self.lower_bound :]
)


def _get_total_time(timer: TimerProtocol) -> float:
total_time = 0.0
for _, durations in timer.recorded_durations.items():
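
As a rough usage sketch of the new BoundedTimer on its own (the values and action name below are illustrative; the constructor arguments mirror the ones used in PhaseState above):

    from torchtnt.utils.timer import BoundedTimer

    timer = BoundedTimer(cuda_sync=False, lower_bound=1_000, upper_bound=5_000)
    for _ in range(10_000):
        with timer.time("my_action"):
            pass  # work being timed

    # At most upper_bound samples are retained per action; once that bound is
    # reached, only the lower_bound most recent samples are kept.
    print(len(timer.recorded_durations["my_action"]))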
