Measure times

Summary:
Add two counters, containing the iteration time and the time blocked on data for the last iteration. These are stored on the state; not sure if the unit would be a better place.

They can be accessed at the start of the next iteration. This is done because we want to capture the time of the whole iteration, including callbacks, so the values cannot be read from within that iteration's own callbacks.

Another possibility would be to keep a list of all the iteration times and data wait times; not sure if that would be preferable.
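
As a rough illustration (not part of this commit), a callback could read the counters at the start of the next step. The IterationTimeLogger name and the print-based reporting below are hypothetical; only Callback, State, and the two new properties come from the framework.

from typing import Any

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State


class IterationTimeLogger(Callback):
    """Hypothetical callback that reports the timing recorded for the previous iteration."""

    def on_train_step_start(self, state: State, unit: Any) -> None:
        train_state = state.train_state
        if train_state is None:
            return
        # Both counters are None until the first full iteration has completed.
        if train_state.last_iteration_time is not None:
            data_wait = train_state.last_iteration_data_wait_time or 0.0
            print(
                f"last iteration: {train_state.last_iteration_time:.4f}s total, "
                f"{data_wait:.4f}s waiting on data"
            )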

Differential Revision: D46853794

fbshipit-source-id: 094c97d2936e33ae3feb9daaeb0529457b08477d
Miquel Jubert Hermoso authored and facebook-github-bot committed Jul 3, 2023
1 parent b39c53e commit ea251ca
Showing 5 changed files with 109 additions and 23 deletions.
13 changes: 13 additions & 0 deletions tests/framework/test_train.py
@@ -48,6 +48,8 @@ def test_train(self) -> None:

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertEqual(state.entry_point, EntryPoint.TRAIN)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_max_steps_per_epoch(self) -> None:
"""
@@ -84,6 +86,8 @@ def test_train_max_steps_per_epoch(self) -> None:
self.assertEqual(state.train_state.step_output, None)

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_stop(self) -> None:
"""
@@ -115,6 +119,8 @@ def test_train_stop(self) -> None:
my_unit.steps_processed, state.train_state.progress.num_steps_completed
)
self.assertEqual(my_unit.steps_processed, steps_before_stopping)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_with_callback(self) -> None:
"""
@@ -147,6 +153,8 @@ def test_train_with_callback(self) -> None:
)
self.assertEqual(callback_mock.on_train_epoch_end.call_count, max_epochs)
self.assertEqual(callback_mock.on_train_end.call_count, 1)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_data_iter_step(self) -> None:
class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
@@ -189,6 +197,9 @@ def train_step(
self.assertEqual(state.train_state.step_output, None)

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertNotEqual(state.train_state.last_iteration_time, None)
# The unit being used does not log the data_wait_time
self.assertEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_max_steps(self) -> None:
max_steps = 3
@@ -226,6 +237,8 @@ def test_train_max_steps(self) -> None:
self.assertEqual(
my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch
)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_auto_timing(self) -> None:
"""
7 changes: 5 additions & 2 deletions torchtnt/framework/auto_unit.py
@@ -41,6 +41,7 @@
from torchtnt.framework.utils import (
_get_timing_context,
_is_fsdp_module,
_log_time_waiting_for_data,
get_current_progress,
)
from torchtnt.utils import (
@@ -653,8 +654,10 @@ def train_step(
self, state: State, data: Iterator[TData]
) -> Tuple[torch.Tensor, Any]:
train_state = none_throws(state.train_state)

batch = self._get_next_batch(state, data)
# __next__ is already decorated and will log this; this call overwrites it
# with an accurate number that accounts for pre-fetching
with _log_time_waiting_for_data(state):
batch = self._get_next_batch(state, data)

should_update_weights = (
train_state.progress.num_steps_completed_in_epoch + 1
14 changes: 14 additions & 0 deletions torchtnt/framework/state.py
@@ -91,6 +91,8 @@ def __init__(
self._evaluate_every_n_epochs = evaluate_every_n_epochs

self._step_output: Any = None
self._last_iteration_time: Optional[float] = None
self._last_iteration_data_wait_time: Optional[float] = None

@property
def dataloader(self) -> Iterable[Any]:
@@ -132,6 +134,18 @@ def step_output(self) -> Any:
"""Output of the last step."""
return self._step_output

@property
def last_iteration_time(self) -> Any:
"""Time it took to run through the last iteration."""
return self._last_iteration_time

@property
def last_iteration_data_wait_time(self) -> Any:
"""Time the process was waiting for data from the dataloader. Usually equivalent
to timing __next__(self):
"""
return self._last_iteration_data_wait_time


class State:
"""Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases.
41 changes: 23 additions & 18 deletions torchtnt/framework/train.py
@@ -18,6 +18,8 @@
_get_timing_context,
_is_done,
_is_epoch_done,
_log_iteration_time,
_log_time_waiting_for_data,
_maybe_set_distributed_sampler_epoch,
_reset_module_training_mode,
_run_callback_fn,
@@ -231,26 +233,29 @@ def _train_epoch_impl(
)
):
try:
if not pass_data_iter_to_step:
# get the next batch from the data iterator
with _get_timing_context(state, "train.next(data_iter)"):
step_input = next(data_iter)

_run_callback_fn(callbacks, "on_train_step_start", state, train_unit)

with _get_timing_context(
state,
f"{train_unit.__class__.__name__}.train_step",
skip_timer=is_auto_unit,
# skip timer if train_unit is a subclass of AutoUnit because there is additional timing in the AutoUnit, and all timing should be mutually exclusive
):
train_state._step_output = train_unit.train_step(state, step_input)
with _log_iteration_time(state):
if not pass_data_iter_to_step:
# get the next batch from the data iterator
with _get_timing_context(
state, "train.next(data_iter)"
), _log_time_waiting_for_data(state):
step_input = next(data_iter)

_run_callback_fn(callbacks, "on_train_step_start", state, train_unit)

with _get_timing_context(
state,
f"{train_unit.__class__.__name__}.train_step",
skip_timer=is_auto_unit,
# skip timer if train_unit is a subclass of AutoUnit because there is additional timing in the AutoUnit, and all timing should be mutually exclusive
):
train_state._step_output = train_unit.train_step(state, step_input)

train_state.progress.increment_step()
_run_callback_fn(callbacks, "on_train_step_end", state, train_unit)
train_state.progress.increment_step()
_run_callback_fn(callbacks, "on_train_step_end", state, train_unit)

# clear step_output to avoid retaining extra memory
train_state._step_output = None
# clear step_output to avoid retaining extra memory
train_state._step_output = None

if (
evaluate_every_n_steps
57 changes: 54 additions & 3 deletions torchtnt/framework/utils.py
@@ -8,12 +8,26 @@
import contextlib
import inspect
import logging

import time
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)

import torch
import torch.nn as nn
import typing_extensions
from pyre_extensions import none_throws
from torch.distributed.fsdp import (
FullyShardedDataParallel,
FullyShardedDataParallel as FSDP,
@@ -26,15 +40,15 @@
from torch.distributed.fsdp._common_utils import _FSDPState


from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import ActivePhase, EntryPoint, State
from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.framework.unit import AppStateMixin
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.progress import Progress

_logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")

# Helper functions common across the loops
def _is_done(
@@ -106,6 +120,43 @@ def _get_timing_context(state: State, event_name: str, skip_timer: bool = False)
yield (timer_context, profiler_context)


@contextmanager
def _log_time_waiting_for_data(state: State) -> Iterator[None]:
"""Returns a context manager to use around the call to __next__, which will
log the time it takes to receive the data"""
try:
start_time: float = time.perf_counter()
yield
finally:
total_time: float = time.perf_counter() - start_time
_get_active_state(state)._last_iteration_data_wait_time = total_time


@contextmanager
def _log_iteration_time(state: State) -> Iterator[None]:
"""Returns a context manager to use around the call to __next__, which will
log the time it takes to receive the data"""
try:
start_time: float = time.perf_counter()
yield
finally:
total_time: float = time.perf_counter() - start_time
_get_active_state(state)._last_iteration_time = total_time


def _get_active_state(state: State) -> PhaseState:
if state.active_phase == ActivePhase.TRAIN:
return none_throws(state.train_state)
elif state.active_phase == ActivePhase.EVALUATE:
return none_throws(state.eval_state)
elif state.active_phase == ActivePhase.PREDICT:
return none_throws(state.predict_state)
else:
raise ValueError(
f"State is not in one of the three active states, {state.active_phase}"
)


def _run_callback_fn(
callbacks: List[Callback],
fn_name: str,