diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index c94033a0db..e43ad50531 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -14,7 +14,7 @@ import tempfile import time import unittest -from typing import cast, Iterable, List, Optional +from typing import Any, cast, Iterable, List, Optional from unittest.mock import MagicMock, patch import torch @@ -42,7 +42,14 @@ from torchtnt.framework.state import ActivePhase, State from torchtnt.framework.train import train -from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData, TTrainUnit +from torchtnt.framework.unit import ( + AppStateMixin, + TEvalData, + TPredictData, + TrainUnit, + TTrainData, + TTrainUnit, +) from torchtnt.utils.checkpoint import BestCheckpointConfig, get_latest_checkpoint_path from torchtnt.utils.distributed import get_global_rank, spawn_multi_process from torchtnt.utils.env import init_from_env @@ -94,10 +101,13 @@ def restore( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, + eval_dataloader: Optional[Iterable[TEvalData]] = None, + predict_dataloader: Optional[Iterable[TPredictData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, msg: str = "", restored_checkpoint_path: Optional[List[str]] = None, + **kwargs: Any, ) -> None: if restored_checkpoint_path is not None: if len(restored_checkpoint_path): diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 19a9a70410..3f3868f3d6 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -410,6 +410,7 @@ def restore( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, 
+ **kwargs: Any, ) -> None: """Method to restore checkpoint state from a path. @@ -419,7 +420,7 @@ def restore( Args: path: Path of the checkpoint to restore. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. - train_dataloader: An optional train dataloader to restore. + train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) restore_options: Controls what to filter when restoring the state. """ @@ -538,6 +539,7 @@ def restore_with_id( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + **kwargs: Any, ) -> None: """Method to restore checkpoint state from a checkpoint id. 
@@ -561,4 +563,5 @@ def restore_with_id( train_dataloader=train_dataloader, process_group=process_group, restore_options=restore_options, + **kwargs, ) diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 9ee8e32a84..06363ebf35 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -37,7 +37,9 @@ from torchtnt.framework.state import State from torchtnt.framework.unit import ( AppStateMixin, + TEvalData, TEvalUnit, + TPredictData, TPredictUnit, TTrainData, TTrainUnit, @@ -226,11 +228,14 @@ def restore( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, + eval_dataloader: Optional[Iterable[TEvalData]] = None, + predict_dataloader: Optional[Iterable[TPredictData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, knob_options: Optional[KnobOptions] = None, planner: Optional[LoadPlanner] = None, storage_reader: Optional[StorageReader] = None, + **kwargs: Any, ) -> None: """Utility method to restore dcp checkpoint from a path.""" @@ -240,6 +245,8 @@ def restore( checkpoint_id, unit, train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + predict_dataloader=predict_dataloader, process_group=process_group, restore_options=restore_options, knob_options=knob_options, @@ -253,11 +260,14 @@ def restore_with_id( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, + eval_dataloader: Optional[Iterable[TEvalData]] = None, + predict_dataloader: Optional[Iterable[TPredictData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, knob_options: Optional[KnobOptions] = None, planner: Optional[LoadPlanner] = None, storage_reader: Optional[StorageReader] = None, + **kwargs: Any, ) -> None: """Utility method to restore dcp checkpoint from a checkpoint_id. 
@@ -267,7 +277,9 @@ def restore_with_id( Args: checkpoint_id: Checkpoint id. It can be the path of the snapshot to restore. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. - train_dataloader: An optional train dataloader to restore. + train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint. + eval_dataloader: An optional eval dataloader to restore. Can only be used when restoring from an eval or fit checkpoint. + predict_dataloader: An optional predict dataloader to restore. Can only be used when restoring from a predict checkpoint. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) If not Gloo, a Gloo process group is created. Note: If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion. diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index 363b2e772d..713324d812 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -270,6 +270,7 @@ def restore( storage_options: Optional[Dict[str, Any]] = None, knob_options: Optional[KnobOptions] = None, strict: bool = True, + **kwargs: Any, ) -> None: """Utility method to restore snapshot state from a path. @@ -279,7 +280,7 @@ def restore( Args: path: Path of the snapshot to restore. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. - train_dataloader: An optional train dataloader to restore. + train_dataloader: An optional train dataloader to restore. 
Note that restoring predict or evaluate dataloaders is not supported for TorchSnapshotSaver. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) restore_options: Controls what to filter when restoring the state. storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.