diff --git a/tests/framework/test_train.py b/tests/framework/test_train.py
index 46af93b589..d2121dfa50 100644
--- a/tests/framework/test_train.py
+++ b/tests/framework/test_train.py
@@ -49,6 +49,8 @@ def test_train(self) -> None:
         self.assertEqual(my_unit.module.training, initial_training_mode)
 
         self.assertEqual(state.entry_point, EntryPoint.TRAIN)
+        self.assertNotEqual(state.train_state.last_iteration_time, None)
+        self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)
 
     def test_train_max_steps_per_epoch(self) -> None:
         """
@@ -85,6 +87,8 @@ def test_train_max_steps_per_epoch(self) -> None:
 
         self.assertEqual(state.train_state.step_output, None)
         self.assertEqual(my_unit.module.training, initial_training_mode)
+        self.assertNotEqual(state.train_state.last_iteration_time, None)
+        self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)
 
     def test_train_stop(self) -> None:
         """
@@ -116,6 +120,8 @@ def test_train_stop(self) -> None:
             my_unit.steps_processed, state.train_state.progress.num_steps_completed
         )
         self.assertEqual(my_unit.steps_processed, steps_before_stopping)
+        self.assertNotEqual(state.train_state.last_iteration_time, None)
+        self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)
 
     def test_train_with_callback(self) -> None:
         """
@@ -148,6 +154,8 @@ def test_train_with_callback(self) -> None:
         )
         self.assertEqual(callback_mock.on_train_epoch_end.call_count, max_epochs)
         self.assertEqual(callback_mock.on_train_end.call_count, 1)
+        self.assertNotEqual(state.train_state.last_iteration_time, None)
+        self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)
 
     def test_train_data_iter_step(self) -> None:
         class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
@@ -190,6 +198,8 @@ def train_step(
 
         self.assertEqual(state.train_state.step_output, None)
         self.assertEqual(my_unit.module.training, initial_training_mode)
+        self.assertNotEqual(state.train_state.last_iteration_time, None)
+        self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)
 
     def test_train_max_steps(self) -> None:
         max_steps = 3
@@ -227,6 +237,8 @@ def test_train_max_steps(self) -> None:
         self.assertEqual(
             my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch
         )
+        self.assertNotEqual(state.train_state.last_iteration_time, None)
+        self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)
 
     def test_train_auto_timing(self) -> None:
         """
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index e3c8d9f889..2175a1fced 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -30,6 +30,7 @@
 from torchtnt.framework.utils import (
     _get_timing_context,
     _is_fsdp_module,
+    _log_time_waiting_for_data,
     get_current_progress,
 )
 from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
@@ -628,8 +629,10 @@ def train_step(
         self, state: State, data: Iterator[TData]
     ) -> Tuple[torch.Tensor, Any]:
         train_state = none_throws(state.train_state)
-
-        batch = self._get_next_batch(state, data)
+        # __next__ is already wrapped and logs the data wait time; this call
+        # overwrites it with a more accurate value that accounts for pre-fetching
+        with _log_time_waiting_for_data(state):
+            batch = self._get_next_batch(state, data)
 
         should_update_weights = (
             train_state.progress.num_steps_completed_in_epoch + 1
diff --git a/torchtnt/framework/state.py b/torchtnt/framework/state.py
index 7afcce8258..381021129c 100644
--- a/torchtnt/framework/state.py
+++ b/torchtnt/framework/state.py
@@ -91,6 +91,8 @@ def __init__(
         self._evaluate_every_n_epochs = evaluate_every_n_epochs
 
         self._step_output: Any = None
+        self._last_iteration_time: Optional[float] = None
+        self._last_iteration_data_wait_time: Optional[float] = None
 
     @property
     def dataloader(self) -> Iterable[Any]:
@@ -132,6 +134,18 @@ def step_output(self) -> Any:
         """Output of the last step."""
         return self._step_output
 
+    @property
+    def last_iteration_time(self) -> Optional[float]:
+        """Time it took to run through the last iteration."""
+        return self._last_iteration_time
+
+    @property
+    def last_iteration_data_wait_time(self) -> Optional[float]:
+        """Time spent waiting for data from the dataloader during the last iteration.
+        Usually equivalent to timing ``__next__`` on the data iterator.
+        """
+        return self._last_iteration_data_wait_time
+
 
 class State:
     """Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases.
diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py
index e4c81c48a4..279303ecb8 100644
--- a/torchtnt/framework/train.py
+++ b/torchtnt/framework/train.py
@@ -19,6 +19,8 @@ from torchtnt.framework.utils import (
     _get_timing_context,
     _is_done,
     _is_epoch_done,
+    _log_iteration_time,
+    _log_time_waiting_for_data,
     _maybe_set_distributed_sampler_epoch,
     _reset_module_training_mode,
     _set_module_training_mode,
@@ -231,40 +233,43 @@ def _train_epoch_impl(
         )
     ):
         try:
-            if not pass_data_iter_to_step:
-                # get the next batch from the data iterator
-                with _get_timing_context(state, "train.next(data_iter)"):
-                    step_input = next(data_iter)
-
-            callback_handler.on_train_step_start(state, train_unit)
-
-            with _get_timing_context(
-                state,
-                f"{train_unit.__class__.__name__}.train_step",
-                skip_timer=is_auto_unit,
-                # skip timer if train_unit is a subclass of AutoUnit because there is additional timing in the AutoUnit, and all timing should be mutually exclusive
-            ):
-                train_state._step_output = train_unit.train_step(state, step_input)
-
-            train_state.progress.increment_step()
-            callback_handler.on_train_step_end(state, train_unit)
-
-            # clear step_output to avoid retaining extra memory
-            train_state._step_output = None
-
-            if (
-                evaluate_every_n_steps
-                and train_state.progress.num_steps_completed % evaluate_every_n_steps
-                == 0
-            ):
-                _evaluate_impl(
+            with _log_time_waiting_for_data(state):
+                if not pass_data_iter_to_step:
+                    # get the next batch from the data iterator
+                    with _get_timing_context(state, "train.next(data_iter)"):
+                        step_input = next(data_iter)
+
+            with _log_iteration_time(state):
+                callback_handler.on_train_step_start(state, train_unit)
+
+                with _get_timing_context(
                     state,
-                    # pyre-ignore: Incompatible parameter type [6]
-                    train_unit,
-                    callback_handler,
-                )
-                logger.info("Finished evaluation. Resuming training epoch")
-                state._active_phase = ActivePhase.TRAIN
+                    f"{train_unit.__class__.__name__}.train_step",
+                    skip_timer=is_auto_unit,
+                    # skip timer if train_unit is a subclass of AutoUnit because there is additional timing in the AutoUnit, and all timing should be mutually exclusive
+                ):
+                    train_state._step_output = train_unit.train_step(state, step_input)
+
+                train_state.progress.increment_step()
+                callback_handler.on_train_step_end(state, train_unit)
+
+                # clear step_output to avoid retaining extra memory
+                train_state._step_output = None
+
+                if (
+                    evaluate_every_n_steps
+                    and train_state.progress.num_steps_completed
+                    % evaluate_every_n_steps
+                    == 0
+                ):
+                    _evaluate_impl(
+                        state,
+                        # pyre-ignore: Incompatible parameter type [6]
+                        train_unit,
+                        callback_handler,
+                    )
+                    logger.info("Finished evaluation. Resuming training epoch")
+                    state._active_phase = ActivePhase.TRAIN
         except StopIteration:
             break
 
diff --git a/torchtnt/framework/utils.py b/torchtnt/framework/utils.py
index 697bcdc590..6826b59a4e 100644
--- a/torchtnt/framework/utils.py
+++ b/torchtnt/framework/utils.py
@@ -8,12 +8,26 @@
 import contextlib
 import inspect
 import logging
+
+import time
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 import torch.nn as nn
 import typing_extensions
+from pyre_extensions import none_throws
 from torch.distributed.fsdp import (
     FullyShardedDataParallel,
     FullyShardedDataParallel as FSDP,
@@ -25,8 +39,9 @@
 from torch.distributed._composable_state import _get_module_state
 from torch.distributed.fsdp._common_utils import _FSDPState
 
-from pyre_extensions import none_throws
-from torchtnt.framework.state import ActivePhase, EntryPoint, State
+
+from torchtnt.framework.callback import Callback
+from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
 from torchtnt.framework.unit import AppStateMixin
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.progress import Progress
@@ -34,6 +49,7 @@
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
 
 # Helper functions common across the loops
 def _is_done(
@@ -105,6 +121,58 @@ def _get_timing_context(state: State, event_name: str, skip_timer: bool = False)
         yield (timer_context, profiler_context)
 
 
+@contextmanager
+def _log_time_waiting_for_data(state: State) -> Iterator[None]:
+    """Returns a context manager to use around the call to ``__next__``, which
+    logs the time spent waiting to receive the data."""
+    try:
+        start_time: float = time.perf_counter()
+        yield
+    finally:
+        total_time: float = time.perf_counter() - start_time
+        _get_active_state(state)._last_iteration_data_wait_time = total_time
+
+
+@contextmanager
+def _log_iteration_time(state: State) -> Iterator[None]:
+    """Returns a context manager to use around a single iteration, which logs
+    the time the iteration takes to run."""
+    try:
+        start_time: float = time.perf_counter()
+        yield
+    finally:
+        total_time: float = time.perf_counter() - start_time
+        _get_active_state(state)._last_iteration_time = total_time
+
+
+def _get_active_state(state: State) -> PhaseState:
+    if state.active_phase == ActivePhase.TRAIN:
+        return none_throws(state.train_state)
+    elif state.active_phase == ActivePhase.EVALUATE:
+        return none_throws(state.eval_state)
+    elif state.active_phase == ActivePhase.PREDICT:
+        return none_throws(state.predict_state)
+    else:
+        raise ValueError(
+            f"State is not in one of the three active phases: {state.active_phase}"
+        )
+
+
+def _run_callback_fn(
+    callbacks: List[Callback],
+    fn_name: str,
+    state: State,
+    *args: Any,
+    **kwargs: Any,
+) -> None:
+    for cb in callbacks:
+        fn = getattr(cb, fn_name)
+        if not callable(fn):
+            raise ValueError(f"Invalid callback method name provided: {fn_name}")
+        with _get_timing_context(state, f"{cb.name}.{fn_name}"):
+            fn(state, *args, **kwargs)
+
+
 def log_api_usage(entry_point: str) -> None:
     torch._C._log_api_usage_once(f"torchtnt.framework.{entry_point}")
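
Example usage (not part of this diff): a minimal sketch of how the new last_iteration_time and last_iteration_data_wait_time fields could be read from a callback once this change lands. The IterationTimeLogger name and the logging format are illustrative assumptions, not APIs introduced here; only the Callback hook signatures and the two new PhaseState properties come from the patch above.

import logging

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit

logger = logging.getLogger(__name__)


class IterationTimeLogger(Callback):
    """Hypothetical callback that logs the per-step timings recorded by the loop."""

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        train_state = state.train_state
        if train_state is None:
            return
        # Both fields are Optional[float]; they may still be None before the
        # first iteration has been fully recorded.
        iteration_time = train_state.last_iteration_time
        data_wait_time = train_state.last_iteration_data_wait_time
        if iteration_time is not None and data_wait_time is not None:
            logger.info(
                "step %d: iteration took %.4fs (data wait %.4fs)",
                train_state.progress.num_steps_completed,
                iteration_time,
                data_wait_time,
            )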