diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index 568b51edb5..39b4dfac01 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -17,8 +17,10 @@ from unittest.mock import MagicMock, patch
 
 import torch
+from pyre_extensions import none_throws
 from torch import nn
 from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
+from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
 )
@@ -28,16 +30,24 @@ from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
+    DummyEvalUnit,
+    DummyMeanMetric,
     DummyMultiOptimUnit,
+    DummyPredictUnit,
     DummyTrainUnit,
+    generate_dummy_stateful_dataloader,
     generate_random_dataloader,
     get_dummy_train_state,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+from torchtnt.framework.evaluate import evaluate
+from torchtnt.framework.fit import fit
+from torchtnt.framework.predict import predict
 from torchtnt.framework.state import State
 from torchtnt.framework.train import train
+from torchtnt.utils.checkpoint import get_latest_checkpoint_path
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -490,6 +500,146 @@ def test_save_restore_multi_optimizers(self) -> None:
             my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
             dcp_cb.restore_from_latest(temp_dir, my_unit_clone)
 
+    def test_save_predict(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyPredictUnit(input_dim=input_dim)
+
+        # pyre-ignore[16]: Add new attribute for testing
+        my_unit.output_mean = DummyMeanMetric()
+
+        # pyre-ignore[16]: Add at least one element to the metric
+        my_unit.output_mean.update(1.0)
+
+        dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_predict_steps=2,
+            )
+
+            predict(my_unit, dataloader, callbacks=[dcp_cb])
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts = [
+                "epoch_0_predict_step_2",
+                "epoch_0_predict_step_4",
+            ]
+
+            self.assertCountEqual(generated_ckpts, expected_ckpts)
+
+            ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
+            self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
+
+            storage_reader = FsspecReader(ckpt_path)
+            metadata = storage_reader.read_metadata()
+            self.assertCountEqual(
+                # Get base keys after the app_state wrapper
+                {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
+                [
+                    "predict_progress",
+                    "predict_dataloader",
+                    "output_mean",
+                ],
+            )
+
+    def test_save_evaluate(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyEvalUnit(input_dim=input_dim)
+
+        dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_eval_steps=2,
+            )
+
+            evaluate(my_unit, dataloader, callbacks=[dcp_cb])
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts = [
+                "epoch_0_eval_step_2",
+                "epoch_0_eval_step_4",
+            ]
+
+            self.assertCountEqual(generated_ckpts, expected_ckpts)
+
+            ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
+            self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
+
+    def test_save_fit(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+        my_unit.output_mean = DummyMeanMetric()
+
+        train_dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        eval_dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_train_steps=2,
+                save_every_n_eval_steps=2,
+            )
+
+            fit(
+                my_unit,
+                max_epochs=1,
+                train_dataloader=train_dataloader,
+                eval_dataloader=eval_dataloader,
+                callbacks=[dcp_cb],
+            )
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts = [
+                "epoch_0_train_step_2_eval_step_0",
+                "epoch_0_train_step_4_eval_step_0",
+                "epoch_1_train_step_5_eval_step_2",
+                "epoch_1_train_step_5_eval_step_4",
+            ]
+            self.assertCountEqual(generated_ckpts, expected_ckpts)
+
+            expected_dataloader = ["train_dataloader"] * 2 + ["eval_dataloader"] * 2
+            for ckpt_path, dl_key in zip(generated_ckpts, expected_dataloader):
+                storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+                metadata = storage_reader.read_metadata()
+                self.assertCountEqual(
+                    # Get base keys after the app_state wrapper
+                    {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
+                    [
+                        "module",  # Both train and eval checkpoints save full app_state in fit
+                        "optimizer",
+                        "lr_scheduler",
+                        "train_progress",
+                        "eval_progress",
+                        "predict_progress",  # included because of AutoUnit
+                        dl_key,
+                        "output_mean",
+                    ],
+                )
+
 
 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:
diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index 7cebc7f5a2..9ee8e32a84 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -93,7 +93,9 @@ def __init__(
         *,
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
+        save_every_n_eval_steps: Optional[int] = None,
         save_every_n_eval_epochs: Optional[int] = None,
+        save_every_n_predict_steps: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
@@ -104,7 +106,9 @@ def __init__(
             dirpath=dirpath,
             save_every_n_train_steps=save_every_n_train_steps,
             save_every_n_epochs=save_every_n_epochs,
+            save_every_n_eval_steps=save_every_n_eval_steps,
             save_every_n_eval_epochs=save_every_n_eval_epochs,
+            save_every_n_predict_steps=save_every_n_predict_steps,
             keep_last_n_checkpoints=keep_last_n_checkpoints,
             best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
@@ -129,10 +133,12 @@ def _checkpoint_impl(
             "on_train_epoch_end",
             "on_train_end",
             "on_eval_epoch_end",
+            "on_eval_step_end",
+            "on_predict_step_end",
         ]:
             raise RuntimeError(f"Unexpected hook encountered '{hook}'")
 
-        intra_epoch = hook == "on_train_step_end"
+        intra_epoch = "step_end" in hook
         curr_snapshot_wait = hook == "on_train_end"
 
         if planner is None:
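Usage note (a reviewer sketch, not part of the diff): the new `save_every_n_eval_steps` and `save_every_n_predict_steps` knobs make the saver fire on the `on_eval_step_end` / `on_predict_step_end` hooks, mirroring the existing `save_every_n_train_steps` behavior. Below is a minimal, self-contained sketch of exercising the predict-side knob, reusing the dummy test utilities imported in the diff above; the expected directory names follow the `epoch_{e}_predict_step_{s}` pattern asserted in `test_save_predict`.

import os
import tempfile

from torchtnt.framework._test_utils import (
    DummyPredictUnit,
    generate_dummy_stateful_dataloader,
)
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.predict import predict

with tempfile.TemporaryDirectory() as ckpt_dir:
    saver = DistributedCheckpointSaver(
        ckpt_dir,
        save_every_n_predict_steps=2,  # checkpoint every 2 predict steps
    )
    # 10 samples with batch size 2 -> 5 predict steps, so checkpoints
    # are written at steps 2 and 4
    dataloader = generate_dummy_stateful_dataloader(10, 2, 2)
    predict(DummyPredictUnit(input_dim=2), dataloader, callbacks=[saver])

    # expected: ['epoch_0_predict_step_2', 'epoch_0_predict_step_4']
    print(sorted(os.listdir(ckpt_dir)))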