diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index 8f18e6fd6b..568b51edb5 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -28,6 +28,7 @@
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
+    DummyMultiOptimUnit,
     DummyTrainUnit,
     generate_random_dataloader,
     get_dummy_train_state,
@@ -471,6 +472,24 @@ def test_gloo_pg_restore(
         self.assertEqual(process_group, None)
         mock_destroy_process_group.assert_not_called()
 
+    def test_save_restore_multi_optimizers(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 1
+
+        my_unit = DummyMultiOptimUnit(input_dim=input_dim)
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+            )
+            train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
+
+            my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
+            dcp_cb.restore_from_latest(temp_dir, my_unit_clone)
+
 
 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:
diff --git a/torchtnt/framework/_test_utils.py b/torchtnt/framework/_test_utils.py
index 15816d946e..45d1100a9f 100644
--- a/torchtnt/framework/_test_utils.py
+++ b/torchtnt/framework/_test_utils.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Iterable, Iterator, Optional, Tuple
+from typing import Iterable, Iterator, List, Optional, Tuple
 
 import torch
 from torch import nn, Tensor
@@ -116,6 +116,42 @@ def train_step(
         self, state: State, data: Batch
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         return loss, outputs
+
+class DummyMultiOptimUnit(TrainUnit[Batch]):
+    def __init__(self, input_dim: int) -> None:
+        super().__init__()
+        # initialize module, loss_fn, & optimizer
+
+        self.modules: List[nn.Module] = [nn.Linear(input_dim, 2) for _ in range(6)]
+        self.loss_fn = nn.CrossEntropyLoss()
+        self.optims = [
+            torch.optim.SGD,
+            torch.optim.Adam,
+            torch.optim.AdamW,
+            torch.optim.Adadelta,
+            torch.optim.NAdam,
+            torch.optim.RMSprop,
+        ]
+        self.applied_optims: List[torch.optim.Optimizer] = []
+        for module, optim in zip(self.modules, self.optims):
+            self.applied_optims.append(optim(module.parameters(), lr=0.1))
+
+    def train_step(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        inputs, targets = data
+
+        outputs = [module(inputs) for module in self.modules]
+        losses = [self.loss_fn(output, targets) for output in outputs]
+        loss = torch.stack(losses).sum()
+        loss.backward()
+
+        for optim in self.applied_optims:
+            optim.step()
+            optim.zero_grad()
+
+        return loss, outputs[0]
+
+
 class DummyFitUnit(TrainUnit[Batch], EvalUnit[Batch]):
     def __init__(self, input_dim: int) -> None:
         super().__init__()
diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index f091e8de94..b1b4a232a5 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -12,7 +12,6 @@
 from datetime import timedelta
 from typing import Any, Dict, Iterable, List, Optional, Union
 
-import torch
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
@@ -46,7 +45,6 @@
 )
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.checkpoint import BestCheckpointConfig
-from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
@@ -323,15 +321,6 @@ def restore_with_id(
                     "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
                 )
 
-        # necessary for loading optimizers since states are initialized lazy
-        for obj in app_state.values():
-            # sometimes optimizers are actually held in a wrapper which handles calling
-            # state_dict and load_state_dict, sa is the case for
-            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
-            optimizer = getattr(obj, "optimizer", obj)
-            if isinstance(optimizer, torch.optim.Optimizer):
-                init_optim_state(optimizer)
-
         dcp.load(
             {"app_state": MultiStateful(app_state)},
             checkpoint_id=checkpoint_id,