Commit

Conditionally include also train_dataloader in fit-eval checkpoints (#918)

Summary: Pull Request resolved: #918

Differential Revision: D63919225
diego-urgell authored and facebook-github-bot committed Oct 4, 2024
1 parent 8bbeb20 commit 7d51490
Showing 2 changed files with 51 additions and 7 deletions.
28 changes: 28 additions & 0 deletions tests/framework/callbacks/test_checkpoint_utils.py
@@ -16,6 +16,7 @@
from torchtnt.framework._test_utils import (
    DummyAutoUnit,
    DummyEvalUnit,
+    DummyFitUnit,
    DummyMeanMetric,
    DummyTrainUnit,
    generate_dummy_stateful_dataloader,
@@ -85,6 +86,33 @@ def test_get_app_state(self) -> None:
            ],
        )

+        # Test evaluate intra-epoch within train epoch on FIT (evaluate_every_n_steps)
+        my_unit = DummyFitUnit(input_dim=2)
+        my_unit.train_progress.increment_step()  # Simulate at least one step in each phase
+        my_unit.eval_progress.increment_step()
+
+        state = get_dummy_fit_state()
+        state._active_phase = ActivePhase.EVALUATE
+
+        train_dl = generate_dummy_stateful_dataloader(1, 1, 1)
+        eval_dl = generate_dummy_stateful_dataloader(1, 1, 1)
+        none_throws(state.train_state)._dataloader = train_dl
+        none_throws(state.eval_state)._dataloader = eval_dl
+
+        app_state = _prepare_app_state_for_checkpoint(state, my_unit, intra_epoch=True)
+        self.assertCountEqual(
+            app_state.keys(),
+            [
+                "module",
+                "optimizer",
+                "loss_fn",
+                "train_progress",
+                "eval_progress",
+                "train_dataloader",
+                "eval_dataloader",
+            ],
+        )
+
    def test_get_step_phase_mapping(self) -> None:
        unit = DummyAutoUnit(module=nn.Linear(2, 2))
        unit.train_progress._num_steps_completed = 5
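For context, the snippet below is a minimal sketch (not part of this commit) contrasting the two intra-epoch cases the new test exercises. It assumes the helpers used above, including get_dummy_fit_state, are importable from torchtnt.framework._test_utils, as the other dummy helpers are: a checkpoint taken while the train phase is active carries only the train dataloader, while a checkpoint taken during fit-eval now carries both.

from pyre_extensions import none_throws

from torchtnt.framework._test_utils import (  # get_dummy_fit_state location assumed
    DummyFitUnit,
    generate_dummy_stateful_dataloader,
    get_dummy_fit_state,
)
from torchtnt.framework.callbacks._checkpoint_utils import (
    _prepare_app_state_for_checkpoint,
)
from torchtnt.framework.state import ActivePhase

unit = DummyFitUnit(input_dim=2)
unit.train_progress.increment_step()  # mid-epoch: at least one train step completed
unit.eval_progress.increment_step()

state = get_dummy_fit_state()
none_throws(state.train_state)._dataloader = generate_dummy_stateful_dataloader(1, 1, 1)
none_throws(state.eval_state)._dataloader = generate_dummy_stateful_dataloader(1, 1, 1)

# Checkpoint requested while training: only the train dataloader state is included.
state._active_phase = ActivePhase.TRAIN
keys = _prepare_app_state_for_checkpoint(state, unit, intra_epoch=True).keys()
assert "train_dataloader" in keys and "eval_dataloader" not in keys

# Checkpoint requested during intra-epoch evaluation: both dataloader states are included.
state._active_phase = ActivePhase.EVALUATE
keys = _prepare_app_state_for_checkpoint(state, unit, intra_epoch=True).keys()
assert "train_dataloader" in keys and "eval_dataloader" in keys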
30 changes: 23 additions & 7 deletions torchtnt/framework/callbacks/_checkpoint_utils.py
@@ -9,8 +9,10 @@

from typing import Any, cast, Dict, Union

+from pyre_extensions import none_throws
+
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.state import ActivePhase, EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.utils.checkpoint import Phase

@@ -123,13 +125,27 @@ def _prepare_app_state_for_checkpoint(
        remove_lr_schedulers=True,
    )

+    if not intra_epoch:
+        return app_state
+
    # for intra-epoch checkpointing, include dataloader state of the current phase
-    phase_dl = state.active_phase_state().dataloader
-    if intra_epoch and isinstance(phase_dl, Stateful):
-        dataloader_state_key = _PHASE_DL_STATE_KEY_MAPPING[
-            state.active_phase.into_phase()
-        ]
-        app_state[dataloader_state_key] = phase_dl
+    active_dataloaders = {state.active_phase: state.active_phase_state().dataloader}
+
+    # Special case for FIT where eval is executed every n steps. We also need to save
+    # the train dataloader state. In this case, train epoch wouldn't be incremented yet.
+    if (
+        state.entry_point == EntryPoint.FIT
+        and state.active_phase == ActivePhase.EVALUATE
+        and cast(TTrainUnit, unit).train_progress.num_steps_completed_in_epoch != 0
+    ):
+        active_dataloaders[ActivePhase.TRAIN] = none_throws(
+            state.train_state
+        ).dataloader
+
+    for active_phase, dl in active_dataloaders.items():
+        if isinstance(dl, Stateful):
+            dl_key = _PHASE_DL_STATE_KEY_MAPPING[active_phase.into_phase()]
+            app_state[dl_key] = dl

    return app_state

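To make the new control flow easier to follow, here is a dependency-free restatement of the selection logic (a sketch, not the library's code). The key names mirror what the new test expects; DL_STATE_KEYS is an illustrative stand-in for _PHASE_DL_STATE_KEY_MAPPING, and the function name is hypothetical.

from enum import Enum, auto
from typing import Dict


class Phase(Enum):
    TRAIN = auto()
    EVALUATE = auto()


# Illustrative stand-in for _PHASE_DL_STATE_KEY_MAPPING.
DL_STATE_KEYS: Dict[Phase, str] = {
    Phase.TRAIN: "train_dataloader",
    Phase.EVALUATE: "eval_dataloader",
}


def dataloader_keys_to_checkpoint(
    active_phase: Phase,
    is_fit: bool,
    train_steps_completed_in_epoch: int,
) -> Dict[Phase, str]:
    """Which dataloader state keys an intra-epoch checkpoint should carry."""
    phases = {active_phase}
    # During fit, an eval-phase checkpoint taken mid-train-epoch must also carry
    # the train dataloader, since the interrupted train epoch will be resumed.
    if is_fit and active_phase is Phase.EVALUATE and train_steps_completed_in_epoch != 0:
        phases.add(Phase.TRAIN)
    return {phase: DL_STATE_KEYS[phase] for phase in phases}


# Eval running every N steps inside a train epoch -> both keys are checkpointed.
assert dataloader_keys_to_checkpoint(Phase.EVALUATE, True, 10) == {
    Phase.EVALUATE: "eval_dataloader",
    Phase.TRAIN: "train_dataloader",
}
# Eval at an epoch boundary (no train steps yet in the new epoch) -> eval only.
assert dataloader_keys_to_checkpoint(Phase.EVALUATE, True, 0) == {
    Phase.EVALUATE: "eval_dataloader"
}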
