Skip to content

Commit

Permalink
Add stateful protocol to utilities
Browse files Browse the repository at this point in the history
Differential Revision: D47410138

fbshipit-source-id: 982a3c57b0d0954b80981a8f72e76f76191840d5
  • Loading branch information
ananthsub authored and facebook-github-bot committed Jul 12, 2023
1 parent fe247c8 commit 5d7884f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 21 deletions.
14 changes: 5 additions & 9 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,19 @@

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import (
_Stateful as StatefulProtocol,
AppStateMixin,
TEvalUnit,
TPredictUnit,
TTrainUnit,
)
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.framework.utils import _construct_tracked_optimizers
from torchtnt.utils import get_global_rank, rank_zero_info, rank_zero_warn
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import Stateful

try:
import torchsnapshot

_TStateful = torchsnapshot.Stateful
_TORCHSNAPSHOT_AVAILABLE = True
except Exception:
_TStateful = StatefulProtocol
_TStateful = Stateful
_TORCHSNAPSHOT_AVAILABLE = False

_EVAL_PROGRESS_STATE_KEY = "eval_progress"
Expand Down
15 changes: 3 additions & 12 deletions torchtnt/framework/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,14 @@
import torch

from torchtnt.framework.state import State
from torchtnt.utils import TLRScheduler
from typing_extensions import Protocol, runtime_checkable
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.stateful import Stateful

"""
This file defines mixins and interfaces for users to customize hooks in training, evaluation, and prediction loops.
"""


@runtime_checkable
class _Stateful(Protocol):
def state_dict(self) -> Dict[str, Any]:
...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...


def _remove_from_dicts(name_to_remove: str, *dicts: Dict[str, Any]) -> None:
for d in dicts:
if name_to_remove in d:
Expand Down Expand Up @@ -131,7 +122,7 @@ def __setattr__(self, name: str, value: Any) -> None:
value,
self.__dict__.get("_lr_schedulers"),
)
elif isinstance(value, _Stateful):
elif isinstance(value, Stateful):
self._update_attr(
name,
value,
Expand Down
2 changes: 2 additions & 0 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
rank_zero_print,
rank_zero_warn,
)
from .stateful import Stateful
from .timer import FullSyncPeriodicTimer, get_timer_summary, Timer
from .tqdm import close_progress_bar, create_progress_bar, update_progress_bar
from .version import (
Expand Down Expand Up @@ -105,6 +106,7 @@
"rank_zero_info",
"rank_zero_print",
"rank_zero_warn",
"Stateful",
"FullSyncPeriodicTimer",
"get_timer_summary",
"transfer_batch_norm_stats",
Expand Down
20 changes: 20 additions & 0 deletions torchtnt/utils/stateful.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
"""Defines the interface for checkpoint saving and loading."""

def state_dict(self) -> Dict[str, Any]:
...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...

0 comments on commit 5d7884f

Please sign in to comment.