Skip to content

Commit

Permalink
Add on_eval_epoch_end as a valid hook to TorchSnapshotSaver
Browse files Browse the repository at this point in the history
Summary:
When `TorchSnapshotSaver` is used with `save_every_n_eval_epochs > 0` and `best_checkpoint_config`, this hook is invoked.

https://www.internalfb.com/code/fbsource/[a8a4a7fba9a8a93af7382fa12e669c066f41024f]/fbcode/torchtnt/framework/callbacks/base_checkpointer.py?lines=273

However, it fails due to not being considered a valid hook. This diff fixes that

Differential Revision: D57083777
  • Loading branch information
dorukhansergin authored and facebook-github-bot committed May 8, 2024
1 parent c62c630 commit d768930
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def _checkpoint_impl(
"""
Checkpoint the current state of the application.
"""
if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
if hook not in [
"on_train_step_end",
"on_train_epoch_end",
"on_train_end",
"on_eval_epoch_end",
]:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

intra_epoch = False
Expand Down

0 comments on commit d768930

Please sign in to comment.