remove init_optim_state in dcp checkpointer (#901)
Summary: Pull Request resolved: #901

Reviewed By: diego-urgell

Differential Revision: D59661542

fbshipit-source-id: a25d5fb991da45187f6a56f4423cb0adad8afe7b
JKSenthil authored and facebook-github-bot committed Sep 17, 2024
1 parent 57a4279 commit 843835c
Showing 3 changed files with 56 additions and 12 deletions.
19 changes: 19 additions & 0 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -28,6 +28,7 @@
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyMultiOptimUnit,
DummyTrainUnit,
generate_random_dataloader,
get_dummy_train_state,
@@ -471,6 +472,24 @@ def test_gloo_pg_restore(
self.assertEqual(process_group, None)
mock_destroy_process_group.assert_not_called()

def test_save_restore_multi_optimizers(self) -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 1

my_unit = DummyMultiOptimUnit(input_dim=input_dim)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
with tempfile.TemporaryDirectory() as temp_dir:
dcp_cb = DistributedCheckpointSaver(
temp_dir,
knob_options=KnobOptions(1),
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])

my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
dcp_cb.restore_from_latest(temp_dir, my_unit_clone)
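
A possible extension of this test (an illustrative sketch only, not part of the commit): after restoring into the clone, one could additionally assert that each restored optimizer's state matches the original's. This reuses check_state_dict_eq, already imported at the top of this file, and assumes its signature is check_state_dict_eq(a, b) -> bool.

# Illustrative follow-up only -- not in the diff above. Assumes
# check_state_dict_eq(a, b) -> bool from torchsnapshot.test_utils.
for orig_optim, restored_optim in zip(
    my_unit.applied_optims, my_unit_clone.applied_optims
):
    self.assertTrue(
        check_state_dict_eq(
            orig_optim.state_dict(), restored_optim.state_dict()
        )
    )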


class DummyStatefulDataLoader:
def __init__(self, dataloader: DataLoader) -> None:
38 changes: 37 additions & 1 deletion torchtnt/framework/_test_utils.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Iterable, Iterator, Optional, Tuple
from typing import Iterable, Iterator, List, Optional, Tuple

import torch
from torch import nn, Tensor
@@ -116,6 +116,42 @@ def train_step(
return loss, outputs


class DummyMultiOptimUnit(TrainUnit[Batch]):
def __init__(self, input_dim: int) -> None:
super().__init__()
# initialize modules, loss_fn, & optimizers

self.modules: List[nn.Module] = [nn.Linear(input_dim, 2) for _ in range(6)]
self.loss_fn = nn.CrossEntropyLoss()
self.optims = [
torch.optim.SGD,
torch.optim.Adam,
torch.optim.AdamW,
torch.optim.Adadelta,
torch.optim.NAdam,
torch.optim.RMSprop,
]
self.applied_optims: List[torch.optim.Optimizer] = []
for module, optim in zip(self.modules, self.optims):
self.applied_optims.append(optim(module.parameters(), lr=0.1))

def train_step(
self, state: State, data: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
inputs, targets = data

outputs = [module(inputs) for module in self.modules]
losses = [self.loss_fn(output, targets) for output in outputs]
loss = torch.stack(losses).sum()
loss.backward()

for optim in self.applied_optims:
optim.step()
optim.zero_grad()

return loss, outputs[0]


class DummyFitUnit(TrainUnit[Batch], EvalUnit[Batch]):
def __init__(self, input_dim: int) -> None:
super().__init__()
11 changes: 0 additions & 11 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -12,7 +12,6 @@
from datetime import timedelta
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as dist
from pyre_extensions import none_throws
from torch.distributed import checkpoint as dcp
@@ -46,7 +45,6 @@
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import MultiStateful, Stateful

@@ -323,15 +321,6 @@ def restore_with_id(
"train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
)

# necessary for loading optimizers since their states are initialized lazily
for obj in app_state.values():
# sometimes optimizers are actually held in a wrapper which handles calling
# state_dict and load_state_dict, as is the case for
# `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
optimizer = getattr(obj, "optimizer", obj)
if isinstance(optimizer, torch.optim.Optimizer):
init_optim_state(optimizer)

dcp.load(
{"app_state": MultiStateful(app_state)},
checkpoint_id=checkpoint_id,
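For context on the block deleted above, here is a minimal standalone sketch (not part of this commit) of the lazy optimizer-state behavior it worked around: PyTorch optimizers such as Adam only materialize their per-parameter state on the first step() call, so a freshly constructed optimizer reports an empty state dict until then.

import torch

model = torch.nn.Linear(2, 2)
optim = torch.optim.Adam(model.parameters(), lr=0.1)

# Freshly constructed: no per-parameter state has been created yet.
print(optim.state_dict()["state"])  # {}

loss = model(torch.randn(4, 2)).sum()
loss.backward()
optim.step()

# After one step, per-parameter entries (step, exp_avg, exp_avg_sq) exist.
print(sorted(optim.state_dict()["state"].keys()))  # [0, 1]

The deleted loop used init_optim_state to pre-populate that state before calling dcp.load; with this commit, the DCP saver's restore path no longer performs that pre-initialization.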
