Measure times (#439)
Summary:
Pull Request resolved: #439

Add two counters: the time of the last iteration and the time spent blocked on data during the last iteration. These are stored on the state; not sure if the unit would be a better place.

They can be accessed at the start of the next iteration. This is because we want to capture the time of the whole iteration, including callbacks, so the value for an iteration cannot be read from within that iteration's own callbacks.

Another possibility would be to keep a list of all the iteration times and data-wait times; not sure if that would be preferable.
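
For illustration, a minimal sketch of how the new counters could be read on the following iteration. This snippet is not part of the diff; the callback class and printing logic are just an example built on top of the torchtnt Callback hooks:

# Hypothetical example, not part of this commit: the values recorded for
# iteration N become readable at the start of iteration N + 1.
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TTrainUnit


class IterationTimePrinter(Callback):
    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        train_state = state.train_state
        if train_state is None:
            return
        total = train_state.last_iteration_time
        data_wait = train_state.last_iteration_data_wait_time
        # Both are None until one full iteration has completed.
        if total is not None and data_wait is not None:
            print(f"last iteration: {total:.4f}s ({data_wait:.4f}s waiting on data)")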

Reviewed By: daniellepintz

Differential Revision: D46853794

fbshipit-source-id: 439a9ae17bf867f5722cd86f21cd42599966bd9c
Miquel Jubert Hermoso authored and facebook-github-bot committed Jul 17, 2023
1 parent 3cbec3c commit 04b08fb
Showing 5 changed files with 140 additions and 38 deletions.
12 changes: 12 additions & 0 deletions tests/framework/test_train.py
@@ -49,6 +49,8 @@ def test_train(self) -> None:

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertEqual(state.entry_point, EntryPoint.TRAIN)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_max_steps_per_epoch(self) -> None:
"""
@@ -85,6 +87,8 @@ def test_train_max_steps_per_epoch(self) -> None:
self.assertEqual(state.train_state.step_output, None)

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_stop(self) -> None:
"""
@@ -116,6 +120,8 @@ def test_train_stop(self) -> None:
my_unit.steps_processed, state.train_state.progress.num_steps_completed
)
self.assertEqual(my_unit.steps_processed, steps_before_stopping)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_with_callback(self) -> None:
"""
@@ -148,6 +154,8 @@ def test_train_with_callback(self) -> None:
)
self.assertEqual(callback_mock.on_train_epoch_end.call_count, max_epochs)
self.assertEqual(callback_mock.on_train_end.call_count, 1)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_data_iter_step(self) -> None:
class TrainIteratorUnit(TrainUnit[Iterator[Tuple[torch.Tensor, torch.Tensor]]]):
@@ -190,6 +198,8 @@ def train_step(
self.assertEqual(state.train_state.step_output, None)

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_max_steps(self) -> None:
max_steps = 3
@@ -227,6 +237,8 @@ def test_train_max_steps(self) -> None:
self.assertEqual(
my_unit.train_step.call_count, max_epochs * expected_steps_per_epoch
)
self.assertNotEqual(state.train_state.last_iteration_time, None)
self.assertNotEqual(state.train_state.last_iteration_data_wait_time, None)

def test_train_auto_timing(self) -> None:
"""
7 changes: 5 additions & 2 deletions torchtnt/framework/auto_unit.py
@@ -30,6 +30,7 @@
from torchtnt.framework.utils import (
_get_timing_context,
_is_fsdp_module,
_log_time_waiting_for_data,
get_current_progress,
)
from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
@@ -628,8 +629,10 @@ def train_step(
self, state: State, data: Iterator[TData]
) -> Tuple[torch.Tensor, Any]:
train_state = none_throws(state.train_state)

    # __next__ is already wrapped and will log a data-wait time, but this call
    # overwrites it with an accurate number that accounts for pre-fetching
    with _log_time_waiting_for_data(state):
        batch = self._get_next_batch(state, data)

should_update_weights = (
train_state.progress.num_steps_completed_in_epoch + 1
14 changes: 14 additions & 0 deletions torchtnt/framework/state.py
@@ -91,6 +91,8 @@ def __init__(
self._evaluate_every_n_epochs = evaluate_every_n_epochs

self._step_output: Any = None
self._last_iteration_time: Optional[float] = None
self._last_iteration_data_wait_time: Optional[float] = None

@property
def dataloader(self) -> Iterable[Any]:
@@ -132,6 +134,18 @@ def step_output(self) -> Any:
"""Output of the last step."""
return self._step_output

@property
def last_iteration_time(self) -> Optional[float]:
    """Time, in seconds, that the last iteration took to run."""
    return self._last_iteration_time

@property
def last_iteration_data_wait_time(self) -> Optional[float]:
    """Time, in seconds, spent waiting for data from the dataloader during the
    last iteration. Usually equivalent to timing ``__next__`` on the data iterator.
    """
    return self._last_iteration_data_wait_time


class State:
"""Parent State class which can contain up to 3 instances of PhaseState, for the 3 phases.
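
As an aside (not part of this diff), the two properties could be combined into a simple data-bottleneck diagnostic; the helper below is hypothetical:

# Hypothetical helper, not part of this commit: fraction of the last iteration
# that was spent waiting on the dataloader, or None if nothing is recorded yet.
from typing import Optional

from torchtnt.framework.state import PhaseState


def data_wait_fraction(phase_state: PhaseState) -> Optional[float]:
    total = phase_state.last_iteration_time
    wait = phase_state.last_iteration_data_wait_time
    if not total or wait is None:
        return None
    return wait / total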
71 changes: 38 additions & 33 deletions torchtnt/framework/train.py
@@ -19,6 +19,8 @@
_get_timing_context,
_is_done,
_is_epoch_done,
_log_iteration_time,
_log_time_waiting_for_data,
_maybe_set_distributed_sampler_epoch,
_reset_module_training_mode,
_set_module_training_mode,
@@ -231,40 +233,43 @@ def _train_epoch_impl(
)
):
try:
    with _log_time_waiting_for_data(state):
        if not pass_data_iter_to_step:
            # get the next batch from the data iterator
            with _get_timing_context(state, "train.next(data_iter)"):
                step_input = next(data_iter)

    with _log_iteration_time(state):
        callback_handler.on_train_step_start(state, train_unit)

        with _get_timing_context(
            state,
            f"{train_unit.__class__.__name__}.train_step",
            skip_timer=is_auto_unit,
            # skip timer if train_unit is a subclass of AutoUnit because there is
            # additional timing in the AutoUnit, and all timing should be mutually exclusive
        ):
            train_state._step_output = train_unit.train_step(state, step_input)

        train_state.progress.increment_step()
        callback_handler.on_train_step_end(state, train_unit)

        # clear step_output to avoid retaining extra memory
        train_state._step_output = None

        if (
            evaluate_every_n_steps
            and train_state.progress.num_steps_completed % evaluate_every_n_steps
            == 0
        ):
            _evaluate_impl(
                state,
                # pyre-ignore: Incompatible parameter type [6]
                train_unit,
                callback_handler,
            )
            logger.info("Finished evaluation. Resuming training epoch")
            state._active_phase = ActivePhase.TRAIN

except StopIteration:
break
74 changes: 71 additions & 3 deletions torchtnt/framework/utils.py
@@ -8,12 +8,26 @@
import contextlib
import inspect
import logging

import time
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)

import torch
import torch.nn as nn
import typing_extensions
from pyre_extensions import none_throws
from torch.distributed.fsdp import (
FullyShardedDataParallel,
FullyShardedDataParallel as FSDP,
@@ -25,15 +39,17 @@
from torch.distributed._composable_state import _get_module_state
from torch.distributed.fsdp._common_utils import _FSDPState

from pyre_extensions import none_throws
from torchtnt.framework.state import ActivePhase, EntryPoint, State

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.framework.unit import AppStateMixin
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.progress import Progress


_logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")

# Helper functions common across the loops
def _is_done(
@@ -105,6 +121,58 @@ def _get_timing_context(state: State, event_name: str, skip_timer: bool = False)
yield (timer_context, profiler_context)


@contextmanager
def _log_time_waiting_for_data(state: State) -> Iterator[None]:
"""Returns a context manager to use around the call to __next__, which will
log the time it takes to receive the data"""
try:
start_time: float = time.perf_counter()
yield
finally:
total_time: float = time.perf_counter() - start_time
_get_active_state(state)._last_iteration_data_wait_time = total_time


@contextmanager
def _log_iteration_time(state: State) -> Iterator[None]:
"""Returns a context manager to use around the call to __next__, which will
log the time it takes to receive the data"""
try:
start_time: float = time.perf_counter()
yield
finally:
total_time: float = time.perf_counter() - start_time
_get_active_state(state)._last_iteration_time = total_time


def _get_active_state(state: State) -> PhaseState:
if state.active_phase == ActivePhase.TRAIN:
return none_throws(state.train_state)
elif state.active_phase == ActivePhase.EVALUATE:
return none_throws(state.eval_state)
elif state.active_phase == ActivePhase.PREDICT:
return none_throws(state.predict_state)
else:
raise ValueError(
f"State is not in one of the three active states, {state.active_phase}"
)


def _run_callback_fn(
callbacks: List[Callback],
fn_name: str,
state: State,
*args: Any,
**kwargs: Any,
) -> None:
for cb in callbacks:
fn = getattr(cb, fn_name)
if not callable(fn):
raise ValueError(f"Invalid callback method name provided: {fn_name}")
with _get_timing_context(state, f"{cb.name}.{fn_name}"):
fn(state, *args, **kwargs)


def log_api_usage(entry_point: str) -> None:
torch._C._log_api_usage_once(f"torchtnt.framework.{entry_point}")

