From ea251caf1045ac826378a208909398c14d7645d4 Mon Sep 17 00:00:00 2001 From: Miquel Jubert Hermoso Date: Mon, 3 Jul 2023 08:26:44 -0700 Subject: [PATCH] Measure times Summary: Add two counters, containing the iteration and time blocked on data for the last iteration. These are stored on the state, not sure if the unit would be a better place. They can be accessed at the start of the next iteration. This is done because we want to capture the time of the whole iteration, including callbacks, so it cannot be accessed by callbacks. Another possibility would be to keep a list of all the times and data times, not sure if that would be preferrable. Differential Revision: D46853794 fbshipit-source-id: 094c97d2936e33ae3feb9daaeb0529457b08477d --- tests/framework/test_train.py | 13 ++++++++ torchtnt/framework/auto_unit.py | 7 ++-- torchtnt/framework/state.py | 14 ++++++++ torchtnt/framework/train.py | 41 +++++++++++++----------- torchtnt/framework/utils.py | 57 +++++++++++++++++++++++++++++++-- 5 files changed, 109 insertions(+), 23 deletions(-) diff --git a/tests/framework/test_train.py b/tests/framework/test_train.py index 1f8cb405a7..ba50156f6f 100644 --- a/tests/framework/test_train.py +++ b/tests/framework/test_train.py @@ -48,6 +48,8 @@ def test_train(self) -> None: self.assertEqual(my_unit.module.training, initial_training_mode) self.assertEqual(state.entry_point, EntryPoint.TRAIN) + self.assertNotEqual(state.train_state.last_iteration_time, None) + self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None) def test_train_max_steps_per_epoch(self) -> None: """ @@ -84,6 +86,8 @@ def test_train_max_steps_per_epoch(self) -> None: self.assertEqual(state.train_state.step_output, None) self.assertEqual(my_unit.module.training, initial_training_mode) + self.assertNotEqual(state.train_state.last_iteration_time, None) + self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None) def test_train_stop(self) -> None: """ @@ -115,6 +119,8 @@ def test_train_stop(self) -> None: my_unit.steps_processed, state.train_state.progress.num_steps_completed ) self.assertEqual(my_unit.steps_processed, steps_before_stopping) + self.assertNotEqual(state.train_state.last_iteration_time, None) + self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None) def test_train_with_callback(self) -> None: """ @@ -147,6 +153,8 @@ def test_train_with_callback(self) -> None: ) self.assertEqual(callback_mock.on_train_epoch_end.call_count, max_epochs) self.assertEqual(callback_mock.on_train_end.call_count, 1) + self.assertNotEqual(state.train_state.last_iteration_time, None) + self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None) def test_train_data_iter_step(self) -> None: class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]): @@ -189,6 +197,9 @@ def train_step( self.assertEqual(state.train_state.step_output, None) self.assertEqual(my_unit.module.training, initial_training_mode) + self.assertNotEqual(state.train_state.last_iteration_time, None) + # The unit being used does not log the data_wait_time + self.assertEqual(state.train_state.last_iteration_data_wait_time, None) def test_train_max_steps(self) -> None: max_steps = 3 @@ -226,6 +237,8 @@ def test_train_max_steps(self) -> None: self.assertEqual( my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch ) + self.assertNotEqual(state.train_state.last_iteration_time, None) + self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None) def test_train_auto_timing(self) -> None: """ diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 6ae5875579..63feac28dc 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -41,6 +41,7 @@ from torchtnt.framework.utils import ( _get_timing_context, _is_fsdp_module, + _log_time_waiting_for_data, get_current_progress, ) from torchtnt.utils import ( @@ -653,8 +654,10 @@ def train_step( self, state: State, data: Iterator[TData] ) -> Tuple[torch.Tensor, Any]: train_state = none_throws(state.train_state) - - batch = self._get_next_batch(state, data) + # __next__ is already decorated and will log this, this call will over-write it + # to set an accurate number according to the pre-fetching + with _log_time_waiting_for_data(state): + batch = self._get_next_batch(state, data) should_update_weights = ( train_state.progress.num_steps_completed_in_epoch + 1 diff --git a/torchtnt/framework/state.py b/torchtnt/framework/state.py index 7afcce8258..381021129c 100644 --- a/torchtnt/framework/state.py +++ b/torchtnt/framework/state.py @@ -91,6 +91,8 @@ def __init__( self._evaluate_every_n_epochs = evaluate_every_n_epochs self._step_output: Any = None + self._last_iteration_time: Optional[float] = None + self._last_iteration_data_wait_time: Optional[float] = None @property def dataloader(self) -> Iterable[Any]: @@ -132,6 +134,18 @@ def step_output(self) -> Any: """Output of the last step.""" return self._step_output + @property + def last_iteration_time(self) -> Any: + """Time it took to run through the last iteration.""" + return self._last_iteration_time + + @property + def last_iteration_data_wait_time(self) -> Any: + """Time the process was waiting for data from the dataloader. Usually equivalent + to timing __next__(self): + """ + return self._last_iteration_data_wait_time + class State: """Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases. diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py index dd96350a60..e146588181 100644 --- a/torchtnt/framework/train.py +++ b/torchtnt/framework/train.py @@ -18,6 +18,8 @@ _get_timing_context, _is_done, _is_epoch_done, + _log_iteration_time, + _log_time_waiting_for_data, _maybe_set_distributed_sampler_epoch, _reset_module_training_mode, _run_callback_fn, @@ -231,26 +233,29 @@ def _train_epoch_impl( ) ): try: - if not pass_data_iter_to_step: - # get the next batch from the data iterator - with _get_timing_context(state, "train.next(data_iter)"): - step_input = next(data_iter) - - _run_callback_fn(callbacks, "on_train_step_start", state, train_unit) - - with _get_timing_context( - state, - f"{train_unit.__class__.__name__}.train_step", - skip_timer=is_auto_unit, - # skip timer if train_unit is a subclass of AutoUnit because there is additional timing in the AutoUnit, and all timing should be mutually exclusive - ): - train_state._step_output = train_unit.train_step(state, step_input) + with _log_iteration_time(state): + if not pass_data_iter_to_step: + # get the next batch from the data iterator + with _get_timing_context( + state, "train.next(data_iter)" + ), _log_time_waiting_for_data(state): + step_input = next(data_iter) + + _run_callback_fn(callbacks, "on_train_step_start", state, train_unit) + + with _get_timing_context( + state, + f"{train_unit.__class__.__name__}.train_step", + skip_timer=is_auto_unit, + # skip timer if train_unit is a subclass of AutoUnit because there is additional timing in the AutoUnit, and all timing should be mutually exclusive + ): + train_state._step_output = train_unit.train_step(state, step_input) - train_state.progress.increment_step() - _run_callback_fn(callbacks, "on_train_step_end", state, train_unit) + train_state.progress.increment_step() + _run_callback_fn(callbacks, "on_train_step_end", state, train_unit) - # clear step_output to avoid retaining extra memory - train_state._step_output = None + # clear step_output to avoid retaining extra memory + train_state._step_output = None if ( evaluate_every_n_steps diff --git a/torchtnt/framework/utils.py b/torchtnt/framework/utils.py index 913451c221..17be3cb47c 100644 --- a/torchtnt/framework/utils.py +++ b/torchtnt/framework/utils.py @@ -8,12 +8,26 @@ import contextlib import inspect import logging + +import time from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) import torch import torch.nn as nn import typing_extensions +from pyre_extensions import none_throws from torch.distributed.fsdp import ( FullyShardedDataParallel, FullyShardedDataParallel as FSDP, @@ -26,15 +40,15 @@ from torch.distributed.fsdp._common_utils import _FSDPState -from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.state import ActivePhase, EntryPoint, State +from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.framework.unit import AppStateMixin from torchtnt.utils.lr_scheduler import TLRScheduler from torchtnt.utils.progress import Progress _logger: logging.Logger = logging.getLogger(__name__) +T = TypeVar("T") # Helper functions common across the loops def _is_done( @@ -106,6 +120,43 @@ def _get_timing_context(state: State, event_name: str, skip_timer: bool = False) yield (timer_context, profiler_context) +@contextmanager +def _log_time_waiting_for_data(state: State) -> Iterator[None]: + """Returns a context manager to use around the call to __next__, which will + log the time it takes to receive the data""" + try: + start_time: float = time.perf_counter() + yield + finally: + total_time: float = time.perf_counter() - start_time + _get_active_state(state)._last_iteration_data_wait_time = total_time + + +@contextmanager +def _log_iteration_time(state: State) -> Iterator[None]: + """Returns a context manager to use around the call to __next__, which will + log the time it takes to receive the data""" + try: + start_time: float = time.perf_counter() + yield + finally: + total_time: float = time.perf_counter() - start_time + _get_active_state(state)._last_iteration_time = total_time + + +def _get_active_state(state: State) -> PhaseState: + if state.active_phase == ActivePhase.TRAIN: + return none_throws(state.train_state) + elif state.active_phase == ActivePhase.EVALUATE: + return none_throws(state.eval_state) + elif state.active_phase == ActivePhase.PREDICT: + return none_throws(state.predict_state) + else: + raise ValueError( + f"State is not in one of the three active states, {state.active_phase}" + ) + + def _run_callback_fn( callbacks: List[Callback], fn_name: str,