Add eval/predict dataloader parameters to restore methods #917

Open · wants to merge 5 commits into base: master
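
The diff below touches only tests, but it sketches the intended surface: the checkpointer gains save_every_n_eval_steps / save_every_n_predict_steps save triggers, and restore(...) accepts eval_dataloader / predict_dataloader alongside the existing train_dataloader. Here is a minimal save-side sketch, assuming torchtnt's TorchSnapshotSaver (a concrete BaseCheckpointer subclass) accepts the same constructor arguments that the test helper below forwards; the units and dataloaders are placeholders, not part of this PR:

from torchtnt.framework.callbacks import TorchSnapshotSaver
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.predict import predict

# Save a checkpoint every 2 eval steps and after every predict step.
saver = TorchSnapshotSaver(
    "/tmp/ckpts",
    save_every_n_eval_steps=2,
    save_every_n_predict_steps=1,
)
# Per the new tests, this yields /tmp/ckpts/epoch_0_eval_step_2, _4, ...
# during evaluation, and /tmp/ckpts/epoch_0_predict_step_1, _2, ...
# during prediction.
evaluate(my_eval_unit, eval_dataloader, callbacks=[saver])
predict(my_predict_unit, predict_dataloader, callbacks=[saver])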
98 changes: 96 additions & 2 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -14,7 +14,7 @@
 import tempfile
 import time
 import unittest
-from typing import cast, Iterable, List, Optional
+from typing import Any, cast, Iterable, List, Optional
 from unittest.mock import MagicMock, patch
 
 import torch
@@ -25,6 +25,7 @@
     Batch,
     DummyAutoUnit,
     DummyFitUnit,
+    DummyPredictUnit,
     DummyTrainUnit,
     generate_random_dataloader,
     get_dummy_fit_state,
@@ -35,11 +36,20 @@
 )
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
 from torchtnt.framework.callbacks.lambda_callback import Lambda
+from torchtnt.framework.evaluate import evaluate
 from torchtnt.framework.fit import fit
+from torchtnt.framework.predict import predict
 from torchtnt.framework.state import ActivePhase, State
 
 from torchtnt.framework.train import train
-from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData, TTrainUnit
+from torchtnt.framework.unit import (
+    AppStateMixin,
+    TEvalData,
+    TPredictData,
+    TrainUnit,
+    TTrainData,
+    TTrainUnit,
+)
 from torchtnt.utils.checkpoint import BestCheckpointConfig, get_latest_checkpoint_path
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import init_from_env
@@ -57,7 +67,9 @@ def __init__(
         *,
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
+        save_every_n_eval_steps: Optional[int] = None,
         save_every_n_eval_epochs: Optional[int] = None,
+        save_every_n_predict_steps: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
@@ -66,7 +78,9 @@
             dirpath,
             save_every_n_train_steps=save_every_n_train_steps,
             save_every_n_epochs=save_every_n_epochs,
+            save_every_n_eval_steps=save_every_n_eval_steps,
             save_every_n_eval_epochs=save_every_n_eval_epochs,
+            save_every_n_predict_steps=save_every_n_predict_steps,
             keep_last_n_checkpoints=keep_last_n_checkpoints,
             best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
@@ -87,10 +101,13 @@ def restore(
         unit: AppStateMixin,
         *,
         train_dataloader: Optional[Iterable[TTrainData]] = None,
+        eval_dataloader: Optional[Iterable[TEvalData]] = None,
+        predict_dataloader: Optional[Iterable[TPredictData]] = None,
         process_group: Optional[dist.ProcessGroup] = None,
         restore_options: Optional[RestoreOptions] = None,
         msg: str = "",
         restored_checkpoint_path: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> None:
         if restored_checkpoint_path is not None:
             if len(restored_checkpoint_path):
@@ -243,6 +260,83 @@ def test_save_fit_entrypoint(self) -> None:
                 checkpointer._latest_checkpoint_path,
             )
 
+    @patch.object(BaseCheckpointSaver, "_checkpoint_impl")
+    def test_save_eval_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None:
+        my_unit = DummyFitUnit(input_dim=2)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            checkpointer = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_eval_steps=2,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="val_loss", mode="min"
+                ),
+                keep_last_n_checkpoints=1,
+            )
+
+            ckpt_container: List[str] = []
+
+            def _checkpoint_impl_side_effect(
+                state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
+            ) -> bool:
+                ckpt_container.append(checkpoint_id)
+                return True
+
+            mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect
+
+            eval_dataloader = generate_random_dataloader(10, 2, 1)
+
+            warning_container: List[str] = []
+            with patch(
+                "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning",
+                side_effect=warning_container.append,
+            ):
+                evaluate(my_unit, eval_dataloader, callbacks=[checkpointer])
+
+            # Verify that checkpoint optimality tracking was disabled
+            self.assertIn(
+                "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints.",
+                warning_container,
+            )
+            self.assertIn(
+                "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints.",
+                warning_container,
+            )
+
+            # Make sure that the correct checkpoints were saved, without tracked metrics
+            expected_ckpts = [
+                f"{temp_dir}/epoch_0_eval_step_{i*2}" for i in range(1, 6)
+            ]
+            self.assertEqual(ckpt_container, expected_ckpts)
+
+    @patch.object(BaseCheckpointSaver, "_checkpoint_impl")
+    def test_save_predict_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None:
+        my_unit = DummyPredictUnit(input_dim=2)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            checkpointer = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_predict_steps=1,
+            )
+
+            ckpt_container: List[str] = []
+
+            def _checkpoint_impl_side_effect(
+                state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
+            ) -> bool:
+                ckpt_container.append(checkpoint_id)
+                return True
+
+            mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect
+
+            predict_dataloader = generate_random_dataloader(10, 2, 1)
+
+            predict(my_unit, predict_dataloader, callbacks=[checkpointer])
+
+            # Make sure that the correct checkpoints were saved
+            expected_ckpts = [
+                f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11)
+            ]
+            self.assertEqual(ckpt_container, expected_ckpts)
+
     @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
     def test_restore_from_latest(self, mock_stdout: MagicMock) -> None:
         my_unit = DummyTrainUnit(input_dim=2)
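The restore side is what the PR title names: eval_dataloader and predict_dataloader now ride along with restore so that a stateful dataloader's position can be recovered mid-run. A hypothetical resume, assuming TorchSnapshotSaver.restore mirrors the extended signature shown in the test helper above (the path and names are illustrative):

from torchtnt.framework.callbacks import TorchSnapshotSaver

# Resume an interrupted evaluation; eval_dataloader is one of the new
# keyword arguments (train_dataloader already existed).
TorchSnapshotSaver.restore(
    "/tmp/ckpts/epoch_0_eval_step_4",
    my_eval_unit,
    eval_dataloader=my_eval_dataloader,  # restored to the position it was saved at
)
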
28 changes: 28 additions & 0 deletions tests/framework/callbacks/test_checkpoint_utils.py
@@ -16,6 +16,7 @@
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     DummyEvalUnit,
+    DummyFitUnit,
     DummyMeanMetric,
     DummyTrainUnit,
     generate_dummy_stateful_dataloader,
@@ -85,6 +86,33 @@ def test_get_app_state(self) -> None:
             ],
         )
 
+        # Test evaluate intra-epoch within train epoch on FIT (evaluate_every_n_steps)
+        my_unit = DummyFitUnit(input_dim=2)
+        my_unit.train_progress.increment_step()  # Simulate at least one step in each phase
+        my_unit.eval_progress.increment_step()
+
+        state = get_dummy_fit_state()
+        state._active_phase = ActivePhase.EVALUATE
+
+        train_dl = generate_dummy_stateful_dataloader(1, 1, 1)
+        eval_dl = generate_dummy_stateful_dataloader(1, 1, 1)
+        none_throws(state.train_state)._dataloader = train_dl
+        none_throws(state.eval_state)._dataloader = eval_dl
+
+        app_state = _prepare_app_state_for_checkpoint(state, my_unit, intra_epoch=True)
+        self.assertCountEqual(
+            app_state.keys(),
+            [
+                "module",
+                "optimizer",
+                "loss_fn",
+                "train_progress",
+                "eval_progress",
+                "train_dataloader",
+                "eval_dataloader",
+            ],
+        )
+
     def test_get_step_phase_mapping(self) -> None:
         unit = DummyAutoUnit(module=nn.Linear(2, 2))
         unit.train_progress._num_steps_completed = 5
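The intra-epoch assertion above hinges on the dataloaders being stateful: the test builds them with generate_dummy_stateful_dataloader, and a dataloader can only appear under the train_dataloader / eval_dataloader keys in the app state if its position can be snapshotted and reloaded. A self-contained sketch of what "stateful" means here, assuming the usual state_dict / load_state_dict protocol that the dummy dataloaders model; this is illustrative code, not torchtnt's:

from typing import Any, Dict, Iterator, List

class TinyStatefulDataLoader:
    """Illustrative dataloader that can checkpoint its position mid-epoch."""

    def __init__(self, data: List[int]) -> None:
        self._data = data
        self._next_index = 0  # position to resume from

    def __iter__(self) -> Iterator[int]:
        while self._next_index < len(self._data):
            item = self._data[self._next_index]
            self._next_index += 1
            yield item

    def state_dict(self) -> Dict[str, Any]:
        # Captured at save time, e.g. under app_state["eval_dataloader"].
        return {"next_index": self._next_index}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Invoked on restore when the dataloader is passed back in.
        self._next_index = state_dict["next_index"]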