Generate predict/evaluate ckpts in DCP Saver (pytorch#915)
Summary: Pull Request resolved: pytorch#915

Differential Revision: D63712524
diego-urgell authored and facebook-github-bot committed Oct 4, 2024
1 parent fa2d7c1 commit ad5a219
Showing 2 changed files with 157 additions and 1 deletion.
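For orientation (illustrative, not part of the diff): this commit adds two new knobs, `save_every_n_eval_steps` and `save_every_n_predict_steps`, so the DCP saver can emit checkpoints mid-loop during evaluate and predict. A minimal usage sketch, built from the same dummy helpers the tests below use:

```python
import tempfile

from torchtnt.framework._test_utils import (
    DummyPredictUnit,
    generate_dummy_stateful_dataloader,
)
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.predict import predict

my_unit = DummyPredictUnit(input_dim=2)
# dataset_len=10, input_dim=2, batch_size=2 -> 5 predict steps per epoch
dataloader = generate_dummy_stateful_dataloader(10, 2, 2)

with tempfile.TemporaryDirectory() as temp_dir:
    dcp_cb = DistributedCheckpointSaver(
        temp_dir,
        save_every_n_predict_steps=2,  # new knob added by this commit
    )
    # Per the tests below, this writes "epoch_0_predict_step_2" and
    # "epoch_0_predict_step_4" under temp_dir.
    predict(my_unit, dataloader, callbacks=[dcp_cb])
```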
150 changes: 150 additions & 0 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -17,8 +17,10 @@
from unittest.mock import MagicMock, patch

import torch
from pyre_extensions import none_throws
from torch import nn
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
@@ -28,16 +30,24 @@
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyEvalUnit,
DummyMeanMetric,
DummyMultiOptimUnit,
DummyPredictUnit,
DummyTrainUnit,
generate_dummy_stateful_dataloader,
generate_random_dataloader,
get_dummy_train_state,
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.fit import fit
from torchtnt.framework.predict import predict

from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.utils.checkpoint import get_latest_checkpoint_path
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -490,6 +500,146 @@ def test_save_restore_multi_optimizers(self) -> None:
my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
dcp_cb.restore_from_latest(temp_dir, my_unit_clone)

def test_save_predict(self) -> None:
input_dim = 2
dataset_len = 10
batch_size = 2

my_unit = DummyPredictUnit(input_dim=input_dim)

# pyre-ignore[16]: Add new attribute for testing
my_unit.output_mean = DummyMeanMetric()

# pyre-ignore[16]: Add at least one element to the metric
my_unit.output_mean.update(1.0)

dataloader = generate_dummy_stateful_dataloader(
dataset_len, input_dim, batch_size
)

with tempfile.TemporaryDirectory() as temp_dir:
dcp_cb = DistributedCheckpointSaver(
temp_dir,
knob_options=KnobOptions(1),
save_every_n_predict_steps=2,
)

predict(my_unit, dataloader, callbacks=[dcp_cb])

generated_ckpts = os.listdir(temp_dir)
expected_ckpts = [
"epoch_0_predict_step_2",
"epoch_0_predict_step_4",
]

self.assertCountEqual(generated_ckpts, expected_ckpts)

ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))

storage_reader = FsspecReader(ckpt_path)
metadata = storage_reader.read_metadata()
self.assertCountEqual(
# Get base keys after the app_state wrapper
{key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
[
"predict_progress",
"predict_dataloader",
"output_mean",
],
)

def test_save_evaluate(self) -> None:
input_dim = 2
dataset_len = 10
batch_size = 2

my_unit = DummyEvalUnit(input_dim=input_dim)

dataloader = generate_dummy_stateful_dataloader(
dataset_len, input_dim, batch_size
)

with tempfile.TemporaryDirectory() as temp_dir:
dcp_cb = DistributedCheckpointSaver(
temp_dir,
knob_options=KnobOptions(1),
save_every_n_eval_steps=2,
)

evaluate(my_unit, dataloader, callbacks=[dcp_cb])

generated_ckpts = os.listdir(temp_dir)
expected_ckpts = [
"epoch_0_eval_step_2",
"epoch_0_eval_step_4",
]

self.assertCountEqual(generated_ckpts, expected_ckpts)

ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))

def test_save_fit(self) -> None:
input_dim = 2
dataset_len = 10
batch_size = 2

my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
my_unit.output_mean = DummyMeanMetric()

train_dataloader = generate_dummy_stateful_dataloader(
dataset_len, input_dim, batch_size
)

eval_dataloader = generate_dummy_stateful_dataloader(
dataset_len, input_dim, batch_size
)

with tempfile.TemporaryDirectory() as temp_dir:
dcp_cb = DistributedCheckpointSaver(
temp_dir,
knob_options=KnobOptions(1),
save_every_n_train_steps=2,
save_every_n_eval_steps=2,
)

fit(
my_unit,
max_epochs=1,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
callbacks=[dcp_cb],
)

generated_ckpts = os.listdir(temp_dir)
expected_ckpts = [
"epoch_0_train_step_2_eval_step_0",
"epoch_0_train_step_4_eval_step_0",
"epoch_1_train_step_5_eval_step_2",
"epoch_1_train_step_5_eval_step_4",
]
self.assertCountEqual(generated_ckpts, expected_ckpts)

expected_dataloader = ["train_dataloader"] * 2 + ["eval_dataloader"] * 2
for ckpt_path, dl_key in zip(generated_ckpts, expected_dataloader):
storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
metadata = storage_reader.read_metadata()
self.assertCountEqual(
# Get base keys after the app_state wrapper
{key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
[
"module", # Both train and eval checkpoints save full app_state in fit
"optimizer",
"lr_scheduler",
"train_progress",
"eval_progress",
"predict_progress", # included because of AutoUnit
dl_key,
"output_mean",
],
)


class DummyStatefulDataLoader:
def __init__(self, dataloader: DataLoader) -> None:
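A side note on the expected checkpoint names in these tests (illustrative arithmetic, not part of the diff): with `dataset_len=10` and `batch_size=2` each loop runs 5 steps, so a save interval of 2 produces checkpoints at steps 2 and 4:

```python
dataset_len, batch_size, save_every_n = 10, 2, 2

num_steps = dataset_len // batch_size  # 5 steps per epoch
ckpt_steps = [s for s in range(1, num_steps + 1) if s % save_every_n == 0]
print(ckpt_steps)  # [2, 4]
print([f"epoch_0_predict_step_{s}" for s in ckpt_steps])
# ['epoch_0_predict_step_2', 'epoch_0_predict_step_4']
```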
8 changes: 7 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
@@ -93,7 +93,9 @@ def __init__(
*,
save_every_n_train_steps: Optional[int] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_eval_steps: Optional[int] = None,
save_every_n_eval_epochs: Optional[int] = None,
save_every_n_predict_steps: Optional[int] = None,
keep_last_n_checkpoints: Optional[int] = None,
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
process_group: Optional[dist.ProcessGroup] = None,
@@ -104,7 +106,9 @@
dirpath=dirpath,
save_every_n_train_steps=save_every_n_train_steps,
save_every_n_epochs=save_every_n_epochs,
save_every_n_eval_steps=save_every_n_eval_steps,
save_every_n_eval_epochs=save_every_n_eval_epochs,
save_every_n_predict_steps=save_every_n_predict_steps,
keep_last_n_checkpoints=keep_last_n_checkpoints,
best_checkpoint_config=best_checkpoint_config,
process_group=process_group,
@@ -129,10 +133,12 @@ def _checkpoint_impl(
"on_train_epoch_end",
"on_train_end",
"on_eval_epoch_end",
"on_eval_step_end",
"on_predict_step_end",
]:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

intra_epoch = hook == "on_train_step_end"
intra_epoch = "step_end" in hook
curr_snapshot_wait = hook == "on_train_end"

if planner is None:
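One behavioral detail in the hunk above: `intra_epoch` previously matched only `on_train_step_end`; with the substring check it now covers all three step-end hooks. A standalone sketch of that dispatch (an assumed simplification of `_checkpoint_impl`'s preamble, not the full method):

```python
# Sketch: how the saver classifies the hook that triggered a checkpoint.
ALLOWED_HOOKS = {
    "on_train_step_end",  # pre-existing; outside the visible hunk
    "on_train_epoch_end",
    "on_train_end",
    "on_eval_epoch_end",
    "on_eval_step_end",     # new in this commit
    "on_predict_step_end",  # new in this commit
}

def classify_hook(hook: str) -> tuple[bool, bool]:
    """Return (intra_epoch, wait_for_snapshot) as derived in the hunk above."""
    if hook not in ALLOWED_HOOKS:
        raise RuntimeError(f"Unexpected hook encountered '{hook}'")
    intra_epoch = "step_end" in hook             # any step-end hook is mid-epoch
    curr_snapshot_wait = hook == "on_train_end"  # only the final save blocks
    return intra_epoch, curr_snapshot_wait
```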
