Measure times #439

Closed
wants to merge 1 commit into from
48 changes: 44 additions & 4 deletions tests/framework/test_train.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Iterator, Tuple
from typing import Any, Iterator, Mapping, Tuple
from unittest.mock import MagicMock

import torch
@@ -16,7 +16,7 @@
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TrainUnit
from torchtnt.framework.unit import TrainUnit, TTrainUnit
from torchtnt.utils.timer import Timer


@@ -32,10 +32,13 @@ def test_train(self) -> None:
expected_steps_per_epoch = dataset_len / batch_size

my_unit = DummyTrainUnit(input_dim=input_dim)
initial_training_mode = my_unit.module.training

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(my_unit, dataloader, max_epochs=max_epochs)
train(
my_unit,
dataloader,
max_epochs=max_epochs,
)

self.assertEqual(my_unit.train_progress.num_epochs_completed, max_epochs)
self.assertEqual(my_unit.train_progress.num_steps_completed_in_epoch, 0)
@@ -138,6 +141,43 @@ def test_train_with_callback(self) -> None:
self.assertEqual(callback_mock.on_train_epoch_end.call_count, max_epochs)
self.assertEqual(callback_mock.on_train_end.call_count, 1)

def test_train_uses_iteration_timer(self) -> None:
"""
Test train records time in the iteration_timer
"""
input_dim = 2
dataset_len = 10
batch_size = 2
max_steps_per_epoch = 1
max_epochs = 1

my_unit = DummyTrainUnit(input_dim=input_dim)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

def assertInTest(key: str, mapping: Mapping[str, Any]) -> None:
self.assertIn(key, mapping)

class CheckTimerUsedCallback(Callback):
def on_train_end(self, state: State, unit: TTrainUnit) -> None:
assertInTest(
"data_wait_time",
state.train_state.iteration_timer.recorded_durations,
)
assertInTest(
"train_iteration_time",
state.train_state.iteration_timer.recorded_durations,
)

check_timer_callback = CheckTimerUsedCallback()

train(
my_unit,
dataloader,
max_epochs=max_epochs,
max_steps_per_epoch=max_steps_per_epoch,
callbacks=[check_timer_callback],
)

def test_train_data_iter_step(self) -> None:
class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
def __init__(self, input_dim: int) -> None:
19 changes: 19 additions & 0 deletions tests/utils/test_timer.py
@@ -17,6 +17,7 @@
import torch.distributed.launcher as launcher
from torchtnt.utils.test_utils import get_pet_launch_config
from torchtnt.utils.timer import (
BoundedTimer,
FullSyncPeriodicTimer,
get_durations_histogram,
get_synced_durations_histogram,
@@ -44,6 +45,24 @@ def test_timer_verbose(self) -> None:
mock_info.assert_called_once()
self.assertTrue("Testing timer took" in mock_info.call_args.args[0])

def test_timer_context_manager_size_bound(self) -> None:
"""Test that timer keeps the number of samples within bounds"""
TEST_ACTION_STRING: str = "test action"
UPPER_BOUND: int = 10
LOWER_BOUND: int = 5
timer = BoundedTimer(lower_bound=LOWER_BOUND, upper_bound=UPPER_BOUND)
for i in range(1000):
with timer.time(TEST_ACTION_STRING):
pass
if i > LOWER_BOUND:
self.assertGreaterEqual(
len(timer.recorded_durations[TEST_ACTION_STRING]), LOWER_BOUND
)
self.assertLessEqual(
len(timer.recorded_durations[TEST_ACTION_STRING]),
UPPER_BOUND,
)

def test_timer_context_manager(self) -> None:
"""Test the context manager in the timer class"""

5 changes: 4 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -630,7 +630,10 @@ def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData:
def train_step(
self, state: State, data: Iterator[TData]
) -> Tuple[torch.Tensor, Any]:
batch = self._get_next_batch(state, data)
# In AutoUnit, data_wait_time and train_iteration_time are not mutually exclusive,
# since data fetching happens as part of the training step
with none_throws(state.train_state).iteration_timer.time("data_wait_time"):
batch = self._get_next_batch(state, data)

should_update_weights = (
self.train_progress.num_steps_completed_in_epoch + 1
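To make the comment above concrete: in AutoUnit the data_wait_time context runs inside the train_iteration_time context on the same timer (see the train loop change further down), so the two recorded durations overlap rather than sum. A minimal standalone sketch of that nesting, using only the Timer API shown in this diff; the stand-in work inside the blocks is illustrative:

from torchtnt.utils.timer import Timer

timer = Timer(cuda_sync=False)

# Outer context: the whole training iteration.
with timer.time("train_iteration_time"):
    # Inner context: data fetching happens inside the step, as in AutoUnit.train_step.
    with timer.time("data_wait_time"):
        batch = [1.0, 2.0, 3.0]  # stand-in for next(data_iter)
    _ = sum(batch)  # stand-in for forward/backward/optimizer step

# Both actions are recorded, but train_iteration_time already includes data_wait_time,
# so summing the two would double-count the wait.
print(timer.recorded_durations["data_wait_time"])
print(timer.recorded_durations["train_iteration_time"])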
12 changes: 11 additions & 1 deletion torchtnt/framework/state.py
@@ -13,7 +13,7 @@
from enum import auto, Enum
from typing import Any, Iterable, Optional

from torchtnt.utils.timer import TimerProtocol
from torchtnt.utils.timer import BoundedTimer, TimerProtocol

_logger: logging.Logger = logging.getLogger(__name__)

@@ -88,6 +88,9 @@ def __init__(
self._evaluate_every_n_epochs = evaluate_every_n_epochs

self._step_output: Any = None
self._iteration_timer = BoundedTimer(
cuda_sync=False, lower_bound=1_000, upper_bound=5_000
)

@property
def dataloader(self) -> Iterable[Any]:
@@ -124,6 +127,13 @@ def step_output(self) -> Any:
"""Output of the last step."""
return self._step_output

@property
def iteration_timer(self) -> TimerProtocol:
"""An always-on :class:`~torchtnt.utils.TimerProtocol` object which contains CPU timings (without synchronisation) of the iterations. For now
only populated during training.
"""
return self._iteration_timer


class State:
"""Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases.
18 changes: 11 additions & 7 deletions torchtnt/framework/train.py
@@ -198,16 +198,20 @@ def _train_epoch_impl(
try:
if not pass_data_iter_to_step:
# get the next batch from the data iterator
with get_timing_context(state, "train.next(data_iter)"):
# If the iterator is passed to step, step is responsible for recording data_wait_time
with get_timing_context(
state, "train.next(data_iter)"
), train_state.iteration_timer.time("data_wait_time"):
step_input = next(data_iter)

callback_handler.on_train_step_start(state, train_unit)
train_state._step_output = train_unit.train_step(state, step_input)
train_unit.train_progress.increment_step()
callback_handler.on_train_step_end(state, train_unit)
with train_state.iteration_timer.time("train_iteration_time"):
callback_handler.on_train_step_start(state, train_unit)
train_state._step_output = train_unit.train_step(state, step_input)
train_unit.train_progress.increment_step()
callback_handler.on_train_step_end(state, train_unit)

# clear step_output to avoid retaining extra memory
train_state._step_output = None
# clear step_output to avoid retaining extra memory
train_state._step_output = None

if (
evaluate_every_n_steps
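With the loop now recording both actions on the iteration timer, a callback can read the new state.train_state.iteration_timer property to estimate how much of each step is spent waiting on data. A minimal sketch, reusing only the Callback, State, and TTrainUnit APIs that appear in the tests above; the callback class, its name, and the printed report are illustrative and not part of this PR:

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit


class DataWaitReportCallback(Callback):
    """Hypothetical callback: reports the share of iteration time spent waiting on data."""

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        durations = state.train_state.iteration_timer.recorded_durations
        data_wait = sum(durations.get("data_wait_time", []))
        iteration = sum(durations.get("train_iteration_time", []))
        if iteration > 0:
            # With AutoUnit the data wait is already contained in the iteration time,
            # so this is wait / total rather than wait / compute.
            print(f"data wait fraction: {data_wait / iteration:.2%}")

Because PhaseState constructs the timer as BoundedTimer(cuda_sync=False, lower_bound=1_000, upper_bound=5_000), at most the last few thousand steps contribute to this ratio.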
42 changes: 42 additions & 0 deletions torchtnt/utils/timer.py
@@ -12,6 +12,7 @@
from contextlib import contextmanager
from time import perf_counter
from typing import (
Any,
Dict,
Generator,
List,
@@ -107,6 +108,9 @@ def __init__(
Args:
cuda_sync: whether to call torch.cuda.synchronize() before and after timing. Defaults to True if CUDA is available.
verbose: whether to enable verbose logging.
size_bounds: defines the range of samples that are kept in the timer. The lower bound must be smaller than
the upper bound. When the number of samples for an action reaches the upper bound, only the lower-bound
most recent samples are kept. This bound is applied per action.

Note:
Enabling cuda_sync will incur a performance hit, but will ensure accurate timings on GPUs.
@@ -156,6 +160,44 @@ def reset(self) -> None:
self.recorded_durations = defaultdict(list)


class BoundedTimer(Timer):
"""
A Timer class which implements TimerProtocol and stores timings in a dictionary `recorded_durations`.

Same behavior as Timer, but bounds the number of recorded samples per action.

Args:
...
lower_bound: number of most recent samples kept per action once the upper bound is reached. Must be greater than 0.
upper_bound: maximum number of samples stored per action; when reached, only the ``lower_bound`` most recent
samples are kept. Must be greater than ``lower_bound``.
"""

def __init__(self, lower_bound: int, upper_bound: int, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert lower_bound > 0
assert lower_bound < upper_bound
self.lower_bound = lower_bound
self.upper_bound = upper_bound

@contextmanager
def time(
self,
action_name: str,
) -> Generator[None, None, None]:
with super().time(action_name):
yield
self._apply_bounds(action_name)

def _apply_bounds(self, action_name: str) -> None:
# Keep 'lower_bound' most recent samples, if at or over upper bound
n_samples: int = len(self.recorded_durations[action_name])
if self.upper_bound <= n_samples:
self.recorded_durations[action_name] = list(
self.recorded_durations[action_name][-self.lower_bound :]
)


def _get_total_time(timer: TimerProtocol) -> float:
total_time = 0.0
for _, durations in timer.recorded_durations.items():
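For reference, a minimal standalone sketch of the trimming behavior implemented above; only BoundedTimer and its time() context from this diff are used, and the action name and loop count are arbitrary:

from torchtnt.utils.timer import BoundedTimer

timer = BoundedTimer(cuda_sync=False, lower_bound=5, upper_bound=10)

for _ in range(1000):
    with timer.time("my_action"):
        pass  # timed work would go here

# _apply_bounds runs each time the context exits: once the per-action sample count
# reaches upper_bound, only the lower_bound most recent samples are kept, so the
# stored list never grows without bound.
n_samples = len(timer.recorded_durations["my_action"])
assert 5 <= n_samples <= 10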