static docs - split utils by module
Summary: Splitting utils static docs by utils module

Differential Revision: D49612893
galrotem authored and facebook-github-bot committed Sep 25, 2023
1 parent cd2c4f8 commit 52f7a1f
Showing 6 changed files with 291 additions and 3 deletions.
270 changes: 267 additions & 3 deletions docs/source/utils/utils.rst
@@ -1,6 +1,270 @@
Utils
=============

.. automodule:: torchtnt.utils
:members:
:undoc-members:
Training-related utilities. These are independent of the framework and can be used as needed.


Device Utils
~~~~~~~~~~~~~~~~~~~~~


.. currentmodule:: torchtnt.utils.device
.. autosummary::
:toctree: generated
:nosignatures:

get_device_from_env
copy_data_to_device
record_data_in_stream
get_nvidia_smi_gpu_stats
get_psutil_cpu_stats
maybe_enable_tf32
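
A minimal usage sketch of resolving the device from the environment and moving a (possibly nested) batch onto it:

.. code-block:: python

    import torch
    from torchtnt.utils.device import copy_data_to_device, get_device_from_env

    # Picks the accelerator for this process from the environment, falling back to CPU.
    device = get_device_from_env()

    # Nested containers of tensors are copied recursively; non-tensor leaves pass through unchanged.
    batch = {"inputs": torch.randn(8, 16), "labels": torch.randint(0, 2, (8,))}
    batch = copy_data_to_device(batch, device)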


Distributed Utils
~~~~~~~~~~~~~~~~~~~~~


.. currentmodule:: torchtnt.utils.distributed
.. autosummary::
:toctree: generated
:nosignatures:

PGWrapper
get_global_rank
get_local_rank
get_world_size
barrier
destroy_process_group
get_process_group_backend_from_device
get_file_init_method
get_tcp_init_method
all_gather_tensors
rank_zero_fn
revert_sync_batchnorm
sync_bool
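
A short sketch of rank-aware logic; it assumes any process group has already been initialized (for example via ``init_from_env``), and that ``rank_zero_fn`` turns the wrapped function into a no-op on non-zero ranks:

.. code-block:: python

    from torchtnt.utils.distributed import get_global_rank, get_world_size, rank_zero_fn

    @rank_zero_fn
    def log_once(msg: str) -> None:
        # Runs only on global rank 0; other ranks skip the body entirely.
        print(f"[{get_world_size()} ranks] {msg}")

    log_once(f"hello from global rank {get_global_rank()}")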


Early Stop Checker
~~~~~~~~~~~~~~~~~~~~~


.. currentmodule:: torchtnt.utils.early_stop_checker
.. autosummary::
:toctree: generated
:nosignatures:

EarlyStopChecker
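
A sketch of plugging ``EarlyStopChecker`` into a validation loop; the constructor arguments (``mode``, ``patience``) and the ``check`` method shown here are assumptions to verify against the class documentation:

.. code-block:: python

    import torch
    from torchtnt.utils.early_stop_checker import EarlyStopChecker

    # Stop once the monitored value has failed to improve for 3 consecutive checks.
    checker = EarlyStopChecker(mode="min", patience=3)

    fake_val_losses = [0.9, 0.8, 0.8, 0.8, 0.8, 0.8]  # plateaus after the second epoch
    for epoch, val_loss in enumerate(fake_val_losses):
        if checker.check(torch.tensor(val_loss)):
            print(f"early stopping at epoch {epoch}")
            break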


Environment Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.env
.. autosummary::
:toctree: generated
:nosignatures:

init_from_env
seed
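
A sketch of the typical startup sequence; ``init_from_env`` is assumed to return the device for the current process and to set up distributed state when launcher environment variables (e.g. from ``torchrun``) are present:

.. code-block:: python

    from torchtnt.utils.env import init_from_env, seed

    # Seed the Python, NumPy, and PyTorch RNGs for reproducibility.
    seed(42)

    # Device for this process; also initializes the process group when launched distributed.
    device = init_from_env()
    print(device)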


Flops Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.flops
.. autosummary::
:toctree: generated
:nosignatures:

FlopTensorDispatchMode
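
A sketch of counting FLOPs for the forward and backward passes; the ``flop_counts`` attribute and ``reset`` method are assumed from the class documentation:

.. code-block:: python

    import copy

    import torch
    from torchtnt.utils.flops import FlopTensorDispatchMode

    module = torch.nn.Linear(in_features=8, out_features=4)
    inp = torch.randn(10, 8)

    with FlopTensorDispatchMode(module) as ftdm:
        # Forward FLOPs, keyed by submodule and operator.
        res = module(inp).mean()
        flops_forward = copy.deepcopy(ftdm.flop_counts)

        # Reset, then count the backward pass separately.
        ftdm.reset()
        res.backward()
        flops_backward = copy.deepcopy(ftdm.flop_counts)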


Filesystem Spec Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.fsspec
.. autosummary::
:toctree: generated
:nosignatures:

get_filesystem


Memory Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.memory
.. autosummary::
:toctree: generated
:nosignatures:

RSSProfiler
measure_rss_deltas


Module Summary Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.module_summary
.. autosummary::
:toctree: generated
:nosignatures:

ModuleSummary
get_module_summary
get_summary_table
prune_module_summary
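
A sketch of printing a summary table; passing example inputs via ``module_args`` (an assumed keyword, to be checked against the function signature) enables size and FLOP columns:

.. code-block:: python

    import torch
    from torchtnt.utils.module_summary import get_module_summary, get_summary_table

    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 2),
    )

    # Per-module parameter counts (and, with example inputs, sizes/FLOPs).
    summary = get_module_summary(model, module_args=(torch.randn(4, 16),))
    print(get_summary_table(summary))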


OOM Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.oom
.. autosummary::
:toctree: generated
:nosignatures:

is_out_of_cpu_memory
is_out_of_cuda_memory
is_out_of_memory_error
log_memory_snapshot
attach_oom_observer
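
A sketch of using the OOM helpers to skip a batch gracefully instead of crashing:

.. code-block:: python

    from typing import Optional

    import torch
    from torchtnt.utils.oom import is_out_of_memory_error

    def try_forward(model: torch.nn.Module, batch: torch.Tensor) -> Optional[torch.Tensor]:
        try:
            return model(batch)
        except RuntimeError as e:
            # True for recognized CPU/CUDA out-of-memory errors.
            if is_out_of_memory_error(e):
                print("skipping batch: out of memory")
                return None
            raise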


Precision Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.precision
.. autosummary::
:toctree: generated
:nosignatures:

convert_precision_str_to_dtype


Prepare Module Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.prepare_module
.. autosummary::
:toctree: generated
:nosignatures:

prepare_module
prepare_ddp
prepare_fsdp
convert_str_to_strategy
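
A sketch of basic usage; ``strategy="ddp"`` assumes an initialized process group (for example via ``init_from_env`` under ``torchrun``):

.. code-block:: python

    import torch
    from torchtnt.utils.env import init_from_env
    from torchtnt.utils.prepare_module import prepare_module

    device = init_from_env()
    model = torch.nn.Linear(16, 2)

    # Moves the module to `device` and wraps it in DistributedDataParallel.
    model = prepare_module(model, device, strategy="ddp")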


Progress Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.progress
.. autosummary::
:toctree: generated
:nosignatures:

Progress
estimated_steps_in_epoch
estimated_steps_in_loop
estimated_steps_in_fit


Rank Zero Log Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.rank_zero_log
.. autosummary::
:toctree: generated
:nosignatures:

rank_zero_print
rank_zero_debug
rank_zero_info
rank_zero_warn
rank_zero_error
rank_zero_critical
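
These helpers mirror the standard ``logging`` levels but emit only on global rank 0, so multi-process jobs do not repeat every message once per rank. A small sketch:

.. code-block:: python

    import logging

    from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Emitted once per job rather than once per rank.
    rank_zero_info("starting training", logger=logger)
    rank_zero_warn("gradient clipping is disabled", logger=logger)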


Stateful
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.stateful
.. autosummary::
:toctree: generated
:nosignatures:

Stateful


Test Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.test_utils
.. autosummary::
:toctree: generated
:nosignatures:

get_pet_launch_config
is_asan
is_tsan
skip_if_asan
spawn_multi_process


Timer Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.timer
.. autosummary::
:toctree: generated
:nosignatures:

log_elapsed_time
TimerProtocol
Timer
FullSyncPeriodicTimer
BoundedTimer
get_timer_summary
get_durations_histogram
get_synced_durations_histogram
get_synced_timer_histogram
get_recorded_durations_table
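
A sketch of timing sections of a loop; the ``Timer.time`` context manager and the string report from ``get_timer_summary`` are assumptions to verify against the timer documentation:

.. code-block:: python

    import time

    from torchtnt.utils.timer import get_timer_summary, Timer

    timer = Timer()
    for _ in range(3):
        with timer.time("data_loading"):
            time.sleep(0.01)  # stand-in for fetching a batch
        with timer.time("train_step"):
            time.sleep(0.02)  # stand-in for forward/backward/optimizer step

    # Aggregated durations, grouped by action name.
    print(get_timer_summary(timer))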


TQDM Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.tqdm
.. autosummary::
:toctree: generated
:nosignatures:

create_progress_bar
update_progress_bar
close_progress_bar


Version Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.version
.. autosummary::
:toctree: generated
:nosignatures:

is_windows
get_python_version
get_torch_version


Misc Utils
~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torchtnt.utils.misc
.. autosummary::
:toctree: generated
:nosignatures:

days_to_secs
transfer_batch_norm_stats
3 changes: 3 additions & 0 deletions torchtnt/utils/__init__.py
@@ -56,6 +56,7 @@
rank_zero_warn,
)
from .stateful import Stateful
from .test_utils import get_pet_launch_config, spawn_multi_process
from .timer import FullSyncPeriodicTimer, get_timer_summary, log_elapsed_time, Timer
from .tqdm import close_progress_bar, create_progress_bar, update_progress_bar
from .version import (
@@ -143,4 +144,6 @@
"is_torch_version_geq_1_9",
"is_torch_version_geq_2_0",
"is_windows",
"get_pet_launch_config",
"spawn_multi_process",
]
1 change: 1 addition & 0 deletions torchtnt/utils/device.py
@@ -125,6 +125,7 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any

def record_data_in_stream(data: T, stream: torch.cuda.streams.Stream) -> None:
"""
Records the tensors in ``data`` on the given stream, so their memory is not reused for other tensors until the stream's pending work completes.
As mentioned in
https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch
uses the "caching allocator" for memory allocation for tensors. When a tensor is
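
# Illustrative sketch (an assumption about typical prefetching usage, not part of this
# diff): copy a batch to the GPU on a side stream, then record it on the consumer
# stream so the caching allocator does not recycle its memory too early.
import torch
from torchtnt.utils.device import copy_data_to_device, record_data_in_stream

if torch.cuda.is_available():
    device = torch.device("cuda")
    side_stream = torch.cuda.Stream()

    cpu_batch = {"inputs": torch.randn(8, 16, pin_memory=True)}
    with torch.cuda.stream(side_stream):
        # Asynchronous host-to-device copy runs on the side stream.
        gpu_batch = copy_data_to_device(cpu_batch, device, non_blocking=True)

    # Make the default stream wait for the copy, then mark the tensors as in use on it.
    torch.cuda.current_stream().wait_stream(side_stream)
    record_data_in_stream(gpu_batch, torch.cuda.current_stream())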
11 changes: 11 additions & 0 deletions torchtnt/utils/prepare_module.py
@@ -280,6 +280,17 @@ def prepare_module(
torch_compile_params: Optional[TorchCompileParams] = None,
activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
) -> torch.nn.Module:
"""
Utility to move a module to a device and set up parallelism, activation checkpointing, and compilation.
Args:
module: module to be used.
device: device to which module will be moved.
strategy: the data parallelization strategy to be used. If a string, must be one of ``ddp`` or ``fsdp``.
swa_params: params for stochastic weight averaging; see https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging.
torch_compile_params: params for ``torch.compile``; see https://pytorch.org/docs/stable/generated/torch.compile.html.
activation_checkpoint_params: params for enabling activation checkpointing.
"""

if strategy:
if isinstance(strategy, str):
6 changes: 6 additions & 0 deletions torchtnt/utils/rank_zero_log.py
@@ -17,6 +17,7 @@


def rank_zero_print(*args: Any, **kwargs: Any) -> None:
"""Call print function only from rank 0."""
if get_global_rank() != 0:
return
print(*args, **kwargs)
@@ -25,6 +26,7 @@ def rank_zero_print(*args: Any, **kwargs: Any) -> None:
def rank_zero_debug(
*args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
"""Log debug message only from rank 0."""
if get_global_rank() != 0:
return
logger = logger or _LOGGER
@@ -36,6 +38,7 @@ def rank_zero_info(
def rank_zero_info(
*args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
"""Log info message only from rank 0."""
if get_global_rank() != 0:
return
logger = logger or _LOGGER
@@ -47,6 +50,7 @@ def rank_zero_warn(
def rank_zero_warn(
*args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
"""Log warn message only from rank 0."""
if get_global_rank() != 0:
return
logger = logger or _LOGGER
@@ -58,6 +62,7 @@ def rank_zero_error(
def rank_zero_error(
*args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
"""Log error message only from rank 0."""
if get_global_rank() != 0:
return
logger = logger or _LOGGER
@@ -69,6 +74,7 @@ def rank_zero_critical(
def rank_zero_critical(
*args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
"""Log critical message only from rank 0."""
if get_global_rank() != 0:
return
logger = logger or _LOGGER
3 changes: 3 additions & 0 deletions torchtnt/utils/tqdm.py
@@ -23,6 +23,7 @@ def create_progress_bar(
max_steps: Optional[int],
max_steps_per_epoch: Optional[int],
) -> tqdm:
"""Constructs a :func:`tqdm` progress bar."""
current_epoch = num_epochs_completed
total = estimated_steps_in_epoch(
dataloader,
@@ -41,13 +42,15 @@ def update_progress_bar(
def update_progress_bar(
progress_bar: tqdm, num_steps_completed: int, refresh_rate: int
) -> None:
"""Updates a progress bar to reflect the number of steps completed."""
if num_steps_completed % refresh_rate == 0:
progress_bar.update(refresh_rate)


def close_progress_bar(
progress_bar: tqdm, num_steps_completed: int, refresh_rate: int
) -> None:
"""Updates and closes a progress bar."""
# complete remaining progress in bar
progress_bar.update(num_steps_completed % refresh_rate)
progress_bar.close()
