From b136b4e9c6f50467b868a738d88d298323fd3992 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 4 Oct 2024 14:47:56 -0700 Subject: [PATCH 1/5] Use no_grad instead of inference_mode for predict with checkpointing (#912) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/912 Reviewed By: mvsfb, JKSenthil Differential Revision: D63491475 --- tests/framework/test_predict.py | 23 +++++++++++++++++++++-- torchtnt/framework/predict.py | 15 ++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tests/framework/test_predict.py b/tests/framework/test_predict.py index 68e7059453..b33ba30128 100644 --- a/tests/framework/test_predict.py +++ b/tests/framework/test_predict.py @@ -8,14 +8,15 @@ # pyre-strict import unittest -from typing import Any, Iterator, Mapping, Tuple -from unittest.mock import MagicMock +from typing import Any, cast, Iterator, List, Mapping, Tuple +from unittest.mock import MagicMock, patch import torch from torch import nn from torchtnt.framework._test_utils import DummyPredictUnit, generate_random_dataloader from torchtnt.framework.callback import Callback +from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver from torchtnt.framework.predict import predict from torchtnt.framework.state import State @@ -223,6 +224,24 @@ def test_error_message(self) -> None: log.output, ) + def test_predict_ckpt_autograd_mode( + self, + ) -> None: + """ + Verify that the pytorch autograd mode used depends on having a checkpoint callback in predict. + """ + unit = DummyPredictUnit(2) + dataloader = generate_random_dataloader(10, 2, 2) + dcp_saver = DistributedCheckpointSaver(dirpath="dummy_dirpath") + + for callbacks, mock_autograd_mode in [ + ([], "torch.inference_mode"), + ([dcp_saver], "torch.no_grad"), + ]: + with patch(mock_autograd_mode) as mock_autograd_mode: + predict(unit, dataloader, callbacks=cast(List[Callback], callbacks)) + mock_autograd_mode.assert_called_once() + Batch = Tuple[torch.Tensor, torch.Tensor] diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index bf414ec72b..c4d362b3f6 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -20,6 +20,7 @@ _set_module_training_mode, ) from torchtnt.framework.callback import Callback +from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.framework.unit import TPredictData, TPredictUnit from torchtnt.framework.utils import get_timing_context @@ -80,7 +81,10 @@ def predict( call on_predict_end on unit first and then callbacks """ _log_api_usage("predict") - callback_handler = CallbackHandler(callbacks or []) + callbacks = callbacks or [] + callback_handler = CallbackHandler(callbacks) + checkpoint_cb_exists = any(isinstance(cb, BaseCheckpointer) for cb in callbacks) + state = State( entry_point=EntryPoint.PREDICT, predict_state=PhaseState( @@ -90,7 +94,13 @@ def predict( timer=timer, ) try: - _predict_impl(state, predict_unit, callback_handler) + # all_gather using inference_mode with gloo backend is not supported. Since this collective + # is necessary for checkpointing, we need to use torch.no_grad instead. + # TODO: remove this once all_gather is supported in inference_mode. 
+ inference_ctx = torch.no_grad if checkpoint_cb_exists else torch.inference_mode + with inference_ctx(): + _predict_impl(state, predict_unit, callback_handler) + logger.info("Finished predict") if state.timer: logger.info(get_timer_summary(state.timer)) @@ -104,7 +114,6 @@ def predict( raise e -@torch.inference_mode() def _predict_impl( state: State, predict_unit: TPredictUnit, From 70031a9826b0e42cb82b59b0369b2b920d7d88d3 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 4 Oct 2024 14:47:56 -0700 Subject: [PATCH 2/5] Generate predict/evaluate checkpoints in BaseCheckpointer (#914) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/914 Reviewed By: JKSenthil Differential Revision: D63013008 --- .../callbacks/test_base_checkpointer.py | 84 ++++++++++++ .../framework/callbacks/base_checkpointer.py | 123 ++++++++++++++---- torchtnt/framework/callbacks/dcp_saver.py | 2 +- torchtnt/framework/unit.py | 10 ++ 4 files changed, 196 insertions(+), 23 deletions(-) diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index d40cd2c236..c94033a0db 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -25,6 +25,7 @@ Batch, DummyAutoUnit, DummyFitUnit, + DummyPredictUnit, DummyTrainUnit, generate_random_dataloader, get_dummy_fit_state, @@ -35,7 +36,9 @@ ) from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.callbacks.lambda_callback import Lambda +from torchtnt.framework.evaluate import evaluate from torchtnt.framework.fit import fit +from torchtnt.framework.predict import predict from torchtnt.framework.state import ActivePhase, State from torchtnt.framework.train import train @@ -57,7 +60,9 @@ def __init__( *, save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, + save_every_n_eval_steps: Optional[int] = None, save_every_n_eval_epochs: Optional[int] = None, + save_every_n_predict_steps: Optional[int] = None, keep_last_n_checkpoints: Optional[int] = None, best_checkpoint_config: Optional[BestCheckpointConfig] = None, process_group: Optional[dist.ProcessGroup] = None, @@ -66,7 +71,9 @@ def __init__( dirpath, save_every_n_train_steps=save_every_n_train_steps, save_every_n_epochs=save_every_n_epochs, + save_every_n_eval_steps=save_every_n_eval_steps, save_every_n_eval_epochs=save_every_n_eval_epochs, + save_every_n_predict_steps=save_every_n_predict_steps, keep_last_n_checkpoints=keep_last_n_checkpoints, best_checkpoint_config=best_checkpoint_config, process_group=process_group, @@ -243,6 +250,83 @@ def test_save_fit_entrypoint(self) -> None: checkpointer._latest_checkpoint_path, ) + @patch.object(BaseCheckpointSaver, "_checkpoint_impl") + def test_save_eval_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None: + my_unit = DummyFitUnit(input_dim=2) + with tempfile.TemporaryDirectory() as temp_dir: + checkpointer = BaseCheckpointSaver( + temp_dir, + save_every_n_eval_steps=2, + best_checkpoint_config=BestCheckpointConfig( + monitored_metric="val_loss", mode="min" + ), + keep_last_n_checkpoints=1, + ) + + ckpt_container: List[str] = [] + + def _checkpoint_impl_side_effect( + state: State, unit: AppStateMixin, checkpoint_id: str, hook: str + ) -> bool: + ckpt_container.append(checkpoint_id) + return True + + mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect + + eval_dataloader = generate_random_dataloader(10, 2, 1) + + warning_container: List[str] = 
[] + with patch( + "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning", + side_effect=warning_container.append, + ): + evaluate(my_unit, eval_dataloader, callbacks=[checkpointer]) + + # Verify that checkpoint optimality tracking was disabled + self.assertIn( + "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints.", + warning_container, + ) + self.assertIn( + "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints.", + warning_container, + ) + + # Make sure that the correct checkpoints were saved, without tracked metrics + expected_ckpts = [ + f"{temp_dir}/epoch_0_eval_step_{i*2}" for i in range(1, 6) + ] + self.assertEqual(ckpt_container, expected_ckpts) + + @patch.object(BaseCheckpointSaver, "_checkpoint_impl") + def test_save_predict_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None: + my_unit = DummyPredictUnit(input_dim=2) + with tempfile.TemporaryDirectory() as temp_dir: + checkpointer = BaseCheckpointSaver( + temp_dir, + save_every_n_predict_steps=1, + ) + + ckpt_container: List[str] = [] + + def _checkpoint_impl_side_effect( + state: State, unit: AppStateMixin, checkpoint_id: str, hook: str + ) -> bool: + ckpt_container.append(checkpoint_id) + return True + + mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect + + predict_dataloader = generate_random_dataloader(10, 2, 1) + + predict(my_unit, predict_dataloader, callbacks=[checkpointer]) + + # Make sure that the correct checkpoints were saved + expected_ckpts = [ + f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11) + ] + self.assertEqual(ckpt_container, expected_ckpts) + @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) def test_restore_from_latest(self, mock_stdout: MagicMock) -> None: my_unit = DummyTrainUnit(input_dim=2) diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 0b98737d65..a411e80163 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -15,10 +15,19 @@ import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping +from torchtnt.framework.callbacks._checkpoint_utils import ( + _get_epoch, + _get_step_phase_mapping, +) from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions -from torchtnt.framework.state import EntryPoint, State -from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit +from torchtnt.framework.state import ActivePhase, EntryPoint, State +from torchtnt.framework.unit import ( + AppStateMixin, + TEvalUnit, + TPredictUnit, + TTrainData, + TTrainUnit, +) from torchtnt.utils.checkpoint import ( BestCheckpointConfig, CheckpointManager, @@ -51,8 +60,11 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta): save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated. save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated. save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit. - keep_last_n_checkpoints: Number of most recent checkpoints to keep. 
If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. - best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. + save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints during training. Use this if wanting to save checkpoints during evaluate. + save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints during training. Use this if wanting to save checkpoints during using predict entrypoint. + keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted + to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. Only supported for train or fit entrypoints. + best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. This param is ignored if not in train or fit entrypoints. process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created. Note: @@ -78,6 +90,8 @@ def __init__( save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, save_every_n_eval_epochs: Optional[int] = None, + save_every_n_eval_steps: Optional[int] = None, + save_every_n_predict_steps: Optional[int] = None, keep_last_n_checkpoints: Optional[int] = None, best_checkpoint_config: Optional[BestCheckpointConfig] = None, process_group: Optional[dist.ProcessGroup] = None, @@ -90,12 +104,23 @@ def __init__( raise ValueError( f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}" ) + if save_every_n_eval_steps is not None and save_every_n_eval_steps <= 0: + raise ValueError( + f"Invalid value passed for save_every_n_eval_steps. Expected to receive either None or positive number, but received {save_every_n_eval_steps}" + ) + if save_every_n_eval_epochs is not None and save_every_n_eval_epochs <= 0: + raise ValueError( + f"Invalid value passed for save_every_n_eval_epochs. Expected to receive either None or positive number, but received {save_every_n_eval_epochs}" + ) + if save_every_n_predict_steps is not None and save_every_n_predict_steps <= 0: + raise ValueError( + f"Invalid value passed for save_every_n_predict_steps. Expected to receive either None or positive number, but received {save_every_n_predict_steps}" + ) if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0: raise ValueError( f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}" ) - self._best_checkpoint_config = best_checkpoint_config if best_checkpoint_config and best_checkpoint_config.mode not in {"min", "max"}: raise ValueError( f"Invalid value passed for best_checkpoint_config.mode. 
Expected to receive 'min' or 'max', but received {best_checkpoint_config.mode}" @@ -104,7 +129,10 @@ def __init__( self._save_every_n_train_steps = save_every_n_train_steps self._save_every_n_epochs = save_every_n_epochs self._save_every_n_eval_epochs = save_every_n_eval_epochs + self._save_every_n_eval_steps = save_every_n_eval_steps + self._save_every_n_predict_steps = save_every_n_predict_steps self._keep_last_n_checkpoints = keep_last_n_checkpoints + self._best_checkpoint_config = best_checkpoint_config self._process_group: Optional[dist.ProcessGroup] = None self._setup_gloo_pg(process_group) @@ -147,7 +175,7 @@ def dirpath(self) -> str: return self._checkpoint_manager.dirpath def _generate_checkpoint_and_upkeep( - self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str + self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str ) -> bool: """ Implementation for saving checkpoint while taking care of checkpoint @@ -162,11 +190,16 @@ def _generate_checkpoint_and_upkeep( True if checkpoint was successfully saved. False otherwise. """ # 1) generate checkpoint name - epoch = cast(TTrainUnit, unit).train_progress.num_epochs_completed + epoch = _get_epoch(state, unit) step_mapping = _get_step_phase_mapping(state, unit) + # 1.1) append metric data only for train checkpoints, if best_checkpoint_config is defined metric_data: Optional[MetricData] = None - if metric_value := self._get_tracked_metric_value(unit): + if ( + self._best_checkpoint_config + and state.active_phase == ActivePhase.TRAIN + and (metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))) + ): metric_data = MetricData( name=none_throws(self._best_checkpoint_config).monitored_metric, value=metric_value, @@ -179,7 +212,8 @@ def _generate_checkpoint_and_upkeep( process_group=self._process_group, ) - # 2) Determine if we should save checkpoint + # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints + # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported. if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path): return False @@ -222,9 +256,7 @@ def _generate_checkpoint_and_upkeep( return True - def _get_tracked_metric_value( - self, unit: Union[TTrainUnit, TEvalUnit] - ) -> Optional[float]: + def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]: """ If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float. 
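
The new save_every_n_eval_steps and save_every_n_predict_steps knobs documented above can be exercised directly from the evaluate and predict entrypoints. A minimal sketch, assuming the DistributedCheckpointSaver updates that land later in this series and the dummy helpers from torchtnt.framework._test_utils; the dirpath is a placeholder:

from torchtnt.framework._test_utils import DummyPredictUnit, generate_random_dataloader
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.predict import predict

# Checkpoint every other predict step. best_checkpoint_config and keep_last_n_checkpoints
# are omitted on purpose: outside the train/fit entrypoints they are disabled with a warning.
unit = DummyPredictUnit(input_dim=2)
dataloader = generate_random_dataloader(10, 2, 2)  # 10 samples, input_dim 2, batch size 2
saver = DistributedCheckpointSaver("/tmp/ckpts", save_every_n_predict_steps=2)  # placeholder dirpath
predict(unit, dataloader, callbacks=[saver])
# With 5 predict steps this yields checkpoints named epoch_0_predict_step_2 and
# epoch_0_predict_step_4, matching the expectations in test_save_predict below.
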
@@ -271,10 +303,9 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None: def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: num_steps_completed = unit.train_progress.num_steps_completed - save_every_n_train_steps = self._save_every_n_train_steps if ( - save_every_n_train_steps is None - or num_steps_completed % save_every_n_train_steps != 0 + not self._save_every_n_train_steps + or num_steps_completed % self._save_every_n_train_steps != 0 ): return @@ -282,22 +313,70 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: epoch = unit.train_progress.num_epochs_completed - save_every_n_epochs = self._save_every_n_epochs - if save_every_n_epochs is None or epoch % save_every_n_epochs != 0: + if not self._save_every_n_epochs or epoch % self._save_every_n_epochs != 0: return self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_epoch_end") + def on_train_end(self, state: State, unit: TTrainUnit) -> None: + self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end") + + def on_eval_start(self, state: State, unit: TEvalUnit) -> None: + if state.entry_point == EntryPoint.EVALUATE: + self._disable_ckpt_optimality_tracking() + + def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None: + num_steps_completed = unit.eval_progress.num_steps_completed + if ( + not self._save_every_n_eval_steps + or num_steps_completed % self._save_every_n_eval_steps != 0 + ): + return + + self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_step_end") + def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None: epoch = unit.eval_progress.num_epochs_completed - save_every_n_eval_epochs = self._save_every_n_eval_epochs - if save_every_n_eval_epochs is None or epoch % save_every_n_eval_epochs != 0: + if ( + not self._save_every_n_eval_epochs + or epoch % self._save_every_n_eval_epochs != 0 + ): return self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_epoch_end") - def on_train_end(self, state: State, unit: TTrainUnit) -> None: - self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end") + def on_predict_start(self, state: State, unit: TPredictUnit) -> None: + self._disable_ckpt_optimality_tracking() + + def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None: + num_steps_completed = unit.predict_progress.num_steps_completed + if ( + not self._save_every_n_predict_steps + or num_steps_completed % self._save_every_n_predict_steps != 0 + ): + return + + self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end") + + def _disable_ckpt_optimality_tracking(self) -> None: + """ + Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints + will not be used. This is useful for eval and predict entrypoints, since checkpoints do not include + model parameters. + """ + if self._best_checkpoint_config: + logger.warning( + "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints." + ) + self._best_checkpoint_config = None + self._checkpoint_manager._best_checkpoint_config = None + + if self._keep_last_n_checkpoints: + logger.warning( + "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints." 
+ ) + self._keep_last_n_checkpoints = None + self._checkpoint_manager._keep_last_n_checkpoints = None @abc.abstractmethod def _checkpoint_impl( diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 14e99a1677..7cebc7f5a2 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -318,7 +318,7 @@ def restore_with_id( ) def _generate_checkpoint_and_upkeep( - self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str + self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str ) -> bool: # if we are still checkpointing, this might cause a collective hang, since several # operations in the base class use the process group. So wait here instead. diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py index 8a1d7ff1e4..7e470f100d 100644 --- a/torchtnt/framework/unit.py +++ b/torchtnt/framework/unit.py @@ -586,6 +586,16 @@ def on_predict_epoch_end(self, state: State) -> None: """ pass + def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None: + """Hook called after successfully saving a checkpoint. + + Args: + state: a :class:`~torchtnt.framework.state.State` object containing metadata about the training run. + checkpoint_id: the ID of the checkpoint that was saved. Depending on the storage type, this may be + a path, a URL or a unique identifier. + """ + pass + def on_predict_end(self, state: State) -> None: """Hook called after prediction ends. From 0bdd4bc9f313b54db633c55bd3de9e4ef04e5925 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 4 Oct 2024 14:47:56 -0700 Subject: [PATCH 3/5] Conditionally include also train_dataloader in fit-eval checkpoints (#918) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/918 Differential Revision: D63919225 --- .../callbacks/test_checkpoint_utils.py | 28 +++++++++++++++++ .../framework/callbacks/_checkpoint_utils.py | 30 ++++++++++++++----- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/tests/framework/callbacks/test_checkpoint_utils.py b/tests/framework/callbacks/test_checkpoint_utils.py index 425aadccef..410b1f17a4 100644 --- a/tests/framework/callbacks/test_checkpoint_utils.py +++ b/tests/framework/callbacks/test_checkpoint_utils.py @@ -16,6 +16,7 @@ from torchtnt.framework._test_utils import ( DummyAutoUnit, DummyEvalUnit, + DummyFitUnit, DummyMeanMetric, DummyTrainUnit, generate_dummy_stateful_dataloader, @@ -85,6 +86,33 @@ def test_get_app_state(self) -> None: ], ) + # Test evaluate intra-epoch within train epoch on FIT (evaluate_every_n_steps) + my_unit = DummyFitUnit(input_dim=2) + my_unit.train_progress.increment_step() # Simulate at least one step in each phase + my_unit.eval_progress.increment_step() + + state = get_dummy_fit_state() + state._active_phase = ActivePhase.EVALUATE + + train_dl = generate_dummy_stateful_dataloader(1, 1, 1) + eval_dl = generate_dummy_stateful_dataloader(1, 1, 1) + none_throws(state.train_state)._dataloader = train_dl + none_throws(state.eval_state)._dataloader = eval_dl + + app_state = _prepare_app_state_for_checkpoint(state, my_unit, intra_epoch=True) + self.assertCountEqual( + app_state.keys(), + [ + "module", + "optimizer", + "loss_fn", + "train_progress", + "eval_progress", + "train_dataloader", + "eval_dataloader", + ], + ) + def test_get_step_phase_mapping(self) -> None: unit = DummyAutoUnit(module=nn.Linear(2, 2)) unit.train_progress._num_steps_completed = 5 diff --git a/torchtnt/framework/callbacks/_checkpoint_utils.py 
b/torchtnt/framework/callbacks/_checkpoint_utils.py index 57a2443b4c..5c3f1abf5a 100644 --- a/torchtnt/framework/callbacks/_checkpoint_utils.py +++ b/torchtnt/framework/callbacks/_checkpoint_utils.py @@ -9,8 +9,10 @@ from typing import Any, cast, Dict, Union +from pyre_extensions import none_throws + from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions -from torchtnt.framework.state import EntryPoint, State +from torchtnt.framework.state import ActivePhase, EntryPoint, State from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit from torchtnt.utils.checkpoint import Phase @@ -123,13 +125,27 @@ def _prepare_app_state_for_checkpoint( remove_lr_schedulers=True, ) + if not intra_epoch: + return app_state + # for intra-epoch checkpointing, include dataloader state of the current phase - phase_dl = state.active_phase_state().dataloader - if intra_epoch and isinstance(phase_dl, Stateful): - dataloader_state_key = _PHASE_DL_STATE_KEY_MAPPING[ - state.active_phase.into_phase() - ] - app_state[dataloader_state_key] = phase_dl + active_dataloaders = {state.active_phase: state.active_phase_state().dataloader} + + # Special case for FIT where eval is executed every n steps. We also need to save + # the train dataloader state. In this case, train epoch wouldn't be incremented yet. + if ( + state.entry_point == EntryPoint.FIT + and state.active_phase == ActivePhase.EVALUATE + and cast(TTrainUnit, unit).train_progress.num_steps_completed_in_epoch != 0 + ): + active_dataloaders[ActivePhase.TRAIN] = none_throws( + state.train_state + ).dataloader + + for active_phase, dl in active_dataloaders.items(): + if isinstance(dl, Stateful): + dl_key = _PHASE_DL_STATE_KEY_MAPPING[active_phase.into_phase()] + app_state[dl_key] = dl return app_state From c86231775948ced9b04b33e3db594c35c1ad5e0e Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 4 Oct 2024 14:47:56 -0700 Subject: [PATCH 4/5] Generate predict/evaluate ckpts in DCP Saver (#915) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/915 Reviewed By: JKSenthil Differential Revision: D63712524 --- tests/framework/callbacks/test_dcp_saver.py | 241 +++++++++++++++++++- torchtnt/framework/callbacks/dcp_saver.py | 12 +- 2 files changed, 250 insertions(+), 3 deletions(-) diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py index 568b51edb5..6b443847e5 100644 --- a/tests/framework/callbacks/test_dcp_saver.py +++ b/tests/framework/callbacks/test_dcp_saver.py @@ -12,13 +12,15 @@ import shutil import tempfile import unittest -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, Tuple from unittest import mock from unittest.mock import MagicMock, patch import torch +from pyre_extensions import none_throws from torch import nn from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader from torch.distributed.checkpoint.default_planner import ( DefaultLoadPlanner, DefaultSavePlanner, @@ -28,16 +30,24 @@ from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq from torchtnt.framework._test_utils import ( DummyAutoUnit, + DummyEvalUnit, + DummyMeanMetric, DummyMultiOptimUnit, + DummyPredictUnit, DummyTrainUnit, + generate_dummy_stateful_dataloader, generate_random_dataloader, get_dummy_train_state, ) from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, 
RestoreOptions from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver +from torchtnt.framework.evaluate import evaluate +from torchtnt.framework.fit import fit +from torchtnt.framework.predict import predict from torchtnt.framework.state import State from torchtnt.framework.train import train +from torchtnt.utils.checkpoint import get_latest_checkpoint_path from torchtnt.utils.distributed import get_global_rank, spawn_multi_process from torchtnt.utils.env import seed from torchtnt.utils.test_utils import skip_if_not_distributed @@ -490,6 +500,235 @@ def test_save_restore_multi_optimizers(self) -> None: my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim) dcp_cb.restore_from_latest(temp_dir, my_unit_clone) + def test_save_predict(self) -> None: + input_dim = 2 + dataset_len = 10 + batch_size = 2 + + my_unit = DummyPredictUnit(input_dim=input_dim) + + # pyre-ignore[16]: Add new attribute for testing + my_unit.output_mean = DummyMeanMetric() + + # pyre-ignore[16]: Add at least one element to the metric + my_unit.output_mean.update(1.0) + + dataloader = generate_dummy_stateful_dataloader( + dataset_len, input_dim, batch_size + ) + + with tempfile.TemporaryDirectory() as temp_dir: + dcp_cb = DistributedCheckpointSaver( + temp_dir, + knob_options=KnobOptions(1), + save_every_n_predict_steps=2, + ) + + predict(my_unit, dataloader, callbacks=[dcp_cb]) + + generated_ckpts = os.listdir(temp_dir) + expected_ckpts = [ + "epoch_0_predict_step_2", + "epoch_0_predict_step_4", + ] + + self.assertCountEqual(generated_ckpts, expected_ckpts) + + ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir)) + self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1])) + + storage_reader = FsspecReader(ckpt_path) + metadata = storage_reader.read_metadata() + self.assertCountEqual( + # Get base keys after the app_state wrapper + {key.split(".")[1] for key in metadata.state_dict_metadata.keys()}, + [ + "predict_progress", + "predict_dataloader", + "output_mean", + ], + ) + + def test_save_evaluate(self) -> None: + input_dim = 2 + dataset_len = 10 + batch_size = 2 + + my_unit = DummyEvalUnit(input_dim=input_dim) + + dataloader = generate_dummy_stateful_dataloader( + dataset_len, input_dim, batch_size + ) + + with tempfile.TemporaryDirectory() as temp_dir: + dcp_cb = DistributedCheckpointSaver( + temp_dir, + knob_options=KnobOptions(1), + save_every_n_eval_steps=2, + ) + + evaluate(my_unit, dataloader, callbacks=[dcp_cb]) + + generated_ckpts = os.listdir(temp_dir) + expected_ckpts = [ + "epoch_0_eval_step_2", + "epoch_0_eval_step_4", + ] + + self.assertCountEqual(generated_ckpts, expected_ckpts) + + ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir)) + self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1])) + + storage_reader = FsspecReader(ckpt_path) + metadata = storage_reader.read_metadata() + self.assertCountEqual( + # Get base keys after the app_state wrapper + {key.split(".")[1] for key in metadata.state_dict_metadata.keys()}, + [ + "eval_progress", + "eval_dataloader", + ], + ) + + def test_save_fit_eval_every_n_epochs(self) -> None: + input_dim = 2 + dataset_len = 10 + batch_size = 2 + + my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2)) + my_unit.output_mean = DummyMeanMetric() + + train_dataloader = generate_dummy_stateful_dataloader( + dataset_len, input_dim, batch_size + ) + + eval_dataloader = generate_dummy_stateful_dataloader( + dataset_len, input_dim, batch_size + ) + + with tempfile.TemporaryDirectory() as temp_dir: + dcp_cb = 
DistributedCheckpointSaver( + temp_dir, + knob_options=KnobOptions(1), + save_every_n_train_steps=2, + save_every_n_eval_steps=2, + ) + + fit( + my_unit, + max_epochs=1, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + evaluate_every_n_epochs=1, + callbacks=[dcp_cb], + ) + + generated_ckpts = os.listdir(temp_dir) + expected_ckpts = [ + "epoch_0_train_step_2_eval_step_0", + "epoch_0_train_step_4_eval_step_0", + "epoch_1_train_step_5_eval_step_2", + "epoch_1_train_step_5_eval_step_4", + ] + self.assertCountEqual(generated_ckpts, expected_ckpts) + + expected_dataloader = ["train_dataloader"] * 2 + ["eval_dataloader"] * 2 + for ckpt_path, dl_key in zip(expected_ckpts, expected_dataloader): + storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path)) + metadata = storage_reader.read_metadata() + self.assertCountEqual( + # Get base keys after the app_state wrapper + {key.split(".")[1] for key in metadata.state_dict_metadata.keys()}, + [ + "module", # Both train and eval checkpoints save full app_state in fit + "optimizer", + "lr_scheduler", + "train_progress", + "eval_progress", + "predict_progress", # included because of AutoUnit + dl_key, + "output_mean", + ], + ) + + def test_save_fit_eval_every_n_steps(self) -> None: + input_dim = 2 + + my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2)) + my_unit.output_mean = DummyMeanMetric() + + train_dataloader = generate_dummy_stateful_dataloader(10, input_dim, 2) + eval_dataloader = generate_dummy_stateful_dataloader(8, input_dim, 2) + + with tempfile.TemporaryDirectory() as temp_dir: + dcp_cb = DistributedCheckpointSaver( + temp_dir, + knob_options=KnobOptions(1), + save_every_n_train_steps=2, + save_every_n_eval_steps=2, + ) + + fit( + my_unit, + max_epochs=1, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + evaluate_every_n_steps=2, + evaluate_every_n_epochs=None, + callbacks=[dcp_cb], + ) + + generated_ckpts = os.listdir(temp_dir) + expected_ckpts_to_dl_mapping: Dict[str, Tuple[str, ...]] = { + # First train 2 steps + "epoch_0_train_step_2_eval_step_0": ("train_dataloader",), + # Then do a whole evaluation (4 steps) + "epoch_0_train_step_2_eval_step_2": ( + "train_dataloader", + "eval_dataloader", + ), + "epoch_0_train_step_2_eval_step_4": ( + "train_dataloader", + "eval_dataloader", + ), + # Then train other two steps + "epoch_0_train_step_4_eval_step_4": ("train_dataloader",), + # Finally do a whole evaluation (4 steps) + "epoch_0_train_step_4_eval_step_6": ( + "train_dataloader", + "eval_dataloader", + ), + "epoch_0_train_step_4_eval_step_8": ( + "train_dataloader", + "eval_dataloader", + ), + # Last checkpoint (on_train_end) + "epoch_1_train_step_5_eval_step_8": (), + } + self.assertCountEqual( + generated_ckpts, [*expected_ckpts_to_dl_mapping.keys()] + ) + + for ckpt_path, expected_dls in expected_ckpts_to_dl_mapping.items(): + print(f"Checking {ckpt_path}") + storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path)) + metadata = storage_reader.read_metadata() + self.assertCountEqual( + # Get base keys after the app_state wrapper + {key.split(".")[1] for key in metadata.state_dict_metadata.keys()}, + [ + "module", # Both train and eval checkpoints save full app_state in fit + "optimizer", + "lr_scheduler", + "train_progress", + "eval_progress", + "predict_progress", # included because of AutoUnit + "output_mean", + *expected_dls, + ], + ) + class DummyStatefulDataLoader: def __init__(self, dataloader: DataLoader) -> None: diff --git a/torchtnt/framework/callbacks/dcp_saver.py 
b/torchtnt/framework/callbacks/dcp_saver.py index 7cebc7f5a2..5bf03b587d 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -65,8 +65,10 @@ class DistributedCheckpointSaver(BaseCheckpointer): Args: dirpath: Parent directory to save snapshots to. save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated. - save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated. + save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated. save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit. + save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints during training. Use this if wanting to save checkpoints during evaluate. + save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints during training. Use this if wanting to save checkpoints during using predict entrypoint. keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. process_group: The process group on which the ranks will communicate on. 
default: ``None`` (the entire world) @@ -93,7 +95,9 @@ def __init__( *, save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, + save_every_n_eval_steps: Optional[int] = None, save_every_n_eval_epochs: Optional[int] = None, + save_every_n_predict_steps: Optional[int] = None, keep_last_n_checkpoints: Optional[int] = None, best_checkpoint_config: Optional[BestCheckpointConfig] = None, process_group: Optional[dist.ProcessGroup] = None, @@ -104,7 +108,9 @@ def __init__( dirpath=dirpath, save_every_n_train_steps=save_every_n_train_steps, save_every_n_epochs=save_every_n_epochs, + save_every_n_eval_steps=save_every_n_eval_steps, save_every_n_eval_epochs=save_every_n_eval_epochs, + save_every_n_predict_steps=save_every_n_predict_steps, keep_last_n_checkpoints=keep_last_n_checkpoints, best_checkpoint_config=best_checkpoint_config, process_group=process_group, @@ -129,10 +135,12 @@ def _checkpoint_impl( "on_train_epoch_end", "on_train_end", "on_eval_epoch_end", + "on_eval_step_end", + "on_predict_step_end", ]: raise RuntimeError(f"Unexpected hook encountered '{hook}'") - intra_epoch = hook == "on_train_step_end" + intra_epoch = "step_end" in hook curr_snapshot_wait = hook == "on_train_end" if planner is None: From c0d8c001bbe0b21fd08cc9a9f4a94073979d449f Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 4 Oct 2024 14:47:56 -0700 Subject: [PATCH 5/5] Add eval/predict dataloader parameters to restore methods (#917) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/917 Differential Revision: D63013005 --- .../framework/callbacks/test_base_checkpointer.py | 14 ++++++++++++-- torchtnt/framework/callbacks/base_checkpointer.py | 5 ++++- torchtnt/framework/callbacks/dcp_saver.py | 14 +++++++++++++- .../framework/callbacks/torchsnapshot_saver.py | 3 ++- 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index c94033a0db..e43ad50531 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -14,7 +14,7 @@ import tempfile import time import unittest -from typing import cast, Iterable, List, Optional +from typing import Any, cast, Iterable, List, Optional from unittest.mock import MagicMock, patch import torch @@ -42,7 +42,14 @@ from torchtnt.framework.state import ActivePhase, State from torchtnt.framework.train import train -from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData, TTrainUnit +from torchtnt.framework.unit import ( + AppStateMixin, + TEvalData, + TPredictData, + TrainUnit, + TTrainData, + TTrainUnit, +) from torchtnt.utils.checkpoint import BestCheckpointConfig, get_latest_checkpoint_path from torchtnt.utils.distributed import get_global_rank, spawn_multi_process from torchtnt.utils.env import init_from_env @@ -94,10 +101,13 @@ def restore( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, + eval_dataloader: Optional[Iterable[TEvalData]] = None, + predict_dataloader: Optional[Iterable[TPredictData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, msg: str = "", restored_checkpoint_path: Optional[List[str]] = None, + **kwargs: Any, ) -> None: if restored_checkpoint_path is not None: if len(restored_checkpoint_path): diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 
a411e80163..ddf02c420e 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -410,6 +410,7 @@ def restore( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + **kwargs: Any, ) -> None: """Method to restore checkpoint state from a path. @@ -419,7 +420,7 @@ def restore( Args: path: Path of the checkpoint to restore. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. - train_dataloader: An optional train dataloader to restore. + train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) restore_options: Controls what to filter when restoring the state. """ @@ -538,6 +539,7 @@ def restore_with_id( train_dataloader: Optional[Iterable[TTrainData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, + **kwargs: Any, ) -> None: """Method to restore checkpoint state from a checkpoint id. @@ -561,4 +563,5 @@ def restore_with_id( train_dataloader=train_dataloader, process_group=process_group, restore_options=restore_options, + **kwargs, ) diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 5bf03b587d..c4ce32f192 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -37,7 +37,9 @@ from torchtnt.framework.state import State from torchtnt.framework.unit import ( AppStateMixin, + TEvalData, TEvalUnit, + TPredictData, TPredictUnit, TTrainData, TTrainUnit, @@ -228,11 +230,14 @@ def restore( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, + eval_dataloader: Optional[Iterable[TEvalData]] = None, + predict_dataloader: Optional[Iterable[TPredictData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, knob_options: Optional[KnobOptions] = None, planner: Optional[LoadPlanner] = None, storage_reader: Optional[StorageReader] = None, + **kwargs: Any, ) -> None: """Utility method to restore dcp checkpoint from a path.""" @@ -242,6 +247,8 @@ def restore( checkpoint_id, unit, train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + predict_dataloader=predict_dataloader, process_group=process_group, restore_options=restore_options, knob_options=knob_options, @@ -255,11 +262,14 @@ def restore_with_id( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, + eval_dataloader: Optional[Iterable[TEvalData]] = None, + predict_dataloader: Optional[Iterable[TPredictData]] = None, process_group: Optional[dist.ProcessGroup] = None, restore_options: Optional[RestoreOptions] = None, knob_options: Optional[KnobOptions] = None, planner: Optional[LoadPlanner] = None, storage_reader: Optional[StorageReader] = None, + **kwargs: Any, ) -> None: """Utility method to restore dcp checkpoint from a checkpoint_id. @@ -269,7 +279,9 @@ def restore_with_id( Args: checkpoint_id: Checkpoint id. It can be the path of the snapshot to restore. 
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. - train_dataloader: An optional train dataloader to restore. + train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint. + eval_dataloader: An optional eval dataloader to restore. Can only be used when restoring from an eval or fit checkpoint. + predict_dataloader: An optional predict dataloader to restore. Can only be used when restoring from a predict checkpoint. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) If not Gloo, a Gloo process group is created. Note: If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion. diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index 363b2e772d..713324d812 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -270,6 +270,7 @@ def restore( storage_options: Optional[Dict[str, Any]] = None, knob_options: Optional[KnobOptions] = None, strict: bool = True, + **kwargs: Any, ) -> None: """Utility method to restore snapshot state from a path. @@ -279,7 +280,7 @@ def restore( Args: path: Path of the snapshot to restore. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. - train_dataloader: An optional train dataloader to restore. + train_dataloader: An optional train dataloader to restore. Note that restoring from predict or evaluate dataloaders is not supported for TorchSnapshotSaver. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) restore_options: Controls what to filter when restoring the state. storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
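
A minimal restore sketch using the new eval/predict dataloader parameters added in this final patch, assuming a predict checkpoint written with a stateful dataloader as in test_save_predict; the checkpoint path, unit, and dataloader below are illustrative placeholders:

from torchtnt.framework._test_utils import (
    DummyPredictUnit,
    generate_dummy_stateful_dataloader,
)
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

# Restore predict_progress into the unit and, via the new predict_dataloader argument,
# the saved dataloader state as well.
unit = DummyPredictUnit(input_dim=2)
dataloader = generate_dummy_stateful_dataloader(10, 2, 2)
DistributedCheckpointSaver.restore(
    "/tmp/ckpts/epoch_0_predict_step_2",  # placeholder checkpoint path
    unit,
    predict_dataloader=dataloader,
)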