Skip to content

Commit

Permalink
Add eval/predict dataloader parameters to restore methods (#917)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #917

Differential Revision: D63013005
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Oct 4, 2024
1 parent ad5a219 commit 8b2beb8
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
14 changes: 12 additions & 2 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tempfile
import time
import unittest
from typing import cast, Iterable, List, Optional
from typing import Any, cast, Iterable, List, Optional
from unittest.mock import MagicMock, patch

import torch
Expand Down Expand Up @@ -42,7 +42,14 @@
from torchtnt.framework.state import ActivePhase, State

from torchtnt.framework.train import train
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData, TTrainUnit
from torchtnt.framework.unit import (
AppStateMixin,
TEvalData,
TPredictData,
TrainUnit,
TTrainData,
TTrainUnit,
)
from torchtnt.utils.checkpoint import BestCheckpointConfig, get_latest_checkpoint_path
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import init_from_env
Expand Down Expand Up @@ -94,10 +101,13 @@ def restore(
unit: AppStateMixin,
*,
train_dataloader: Optional[Iterable[TTrainData]] = None,
eval_dataloader: Optional[Iterable[TEvalData]] = None,
predict_dataloader: Optional[Iterable[TPredictData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
msg: str = "",
restored_checkpoint_path: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
if restored_checkpoint_path is not None:
if len(restored_checkpoint_path):
Expand Down
5 changes: 4 additions & 1 deletion torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def restore(
train_dataloader: Optional[Iterable[TTrainData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
**kwargs: Any,
) -> None:
"""Method to restore checkpoint state from a path.
Expand All @@ -419,7 +420,7 @@ def restore(
Args:
path: Path of the checkpoint to restore.
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
train_dataloader: An optional train dataloader to restore.
train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint.
        process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
restore_options: Controls what to filter when restoring the state.
"""
Expand Down Expand Up @@ -538,6 +539,7 @@ def restore_with_id(
train_dataloader: Optional[Iterable[TTrainData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
**kwargs: Any,
) -> None:
"""Method to restore checkpoint state from a checkpoint id.
Expand All @@ -561,4 +563,5 @@ def restore_with_id(
train_dataloader=train_dataloader,
process_group=process_group,
restore_options=restore_options,
**kwargs,
)
14 changes: 13 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
TEvalData,
TEvalUnit,
TPredictData,
TPredictUnit,
TTrainData,
TTrainUnit,
Expand Down Expand Up @@ -226,11 +228,14 @@ def restore(
unit: AppStateMixin,
*,
train_dataloader: Optional[Iterable[TTrainData]] = None,
eval_dataloader: Optional[Iterable[TEvalData]] = None,
predict_dataloader: Optional[Iterable[TPredictData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
knob_options: Optional[KnobOptions] = None,
planner: Optional[LoadPlanner] = None,
storage_reader: Optional[StorageReader] = None,
**kwargs: Any,
) -> None:
"""Utility method to restore dcp checkpoint from a path."""

Expand All @@ -240,6 +245,8 @@ def restore(
checkpoint_id,
unit,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
predict_dataloader=predict_dataloader,
process_group=process_group,
restore_options=restore_options,
knob_options=knob_options,
Expand All @@ -253,11 +260,14 @@ def restore_with_id(
unit: AppStateMixin,
*,
train_dataloader: Optional[Iterable[TTrainData]] = None,
eval_dataloader: Optional[Iterable[TEvalData]] = None,
predict_dataloader: Optional[Iterable[TPredictData]] = None,
process_group: Optional[dist.ProcessGroup] = None,
restore_options: Optional[RestoreOptions] = None,
knob_options: Optional[KnobOptions] = None,
planner: Optional[LoadPlanner] = None,
storage_reader: Optional[StorageReader] = None,
**kwargs: Any,
) -> None:
"""Utility method to restore dcp checkpoint from a checkpoint_id.
Expand All @@ -267,7 +277,9 @@ def restore_with_id(
Args:
checkpoint_id: Checkpoint id. It can be the path of the snapshot to restore.
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
train_dataloader: An optional train dataloader to restore.
train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint.
eval_dataloader: An optional eval dataloader to restore. Can only be used when restoring from an eval or fit checkpoint.
predict_dataloader: An optional predict dataloader to restore. Can only be used when restoring from a predict checkpoint.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
If not Gloo, a Gloo process group is created.
Note: If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
Expand Down
3 changes: 2 additions & 1 deletion torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def restore(
storage_options: Optional[Dict[str, Any]] = None,
knob_options: Optional[KnobOptions] = None,
strict: bool = True,
**kwargs: Any,
) -> None:
"""Utility method to restore snapshot state from a path.
Expand All @@ -279,7 +280,7 @@ def restore(
Args:
path: Path of the snapshot to restore.
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
train_dataloader: An optional train dataloader to restore.
train_dataloader: An optional train dataloader to restore. Note that restoring from predict or evaluate dataloaders is not supported for TorchSnapshotSaver.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
restore_options: Controls what to filter when restoring the state.
storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
Expand Down

0 comments on commit 8b2beb8

Please sign in to comment.