Skip to content

Commit

Permalink
Remove setting module training mode in _train_epoch_impl (#464)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #464

We currently call `_set_module_training_mode`/`_reset_module_training_mode` in two places in train.py. We did this to support the `train_epoch` API but this was removed in (#421) so we no longer need this.

Reviewed By: anshulverma

Differential Revision: D47621624

fbshipit-source-id: 8ae44fd8ca44e87ac5ee65b877b85431413ebaf1
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Jul 20, 2023
1 parent 486bc9c commit 5c9fe83
Showing 1 changed file with 0 additions and 10 deletions.
10 changes: 0 additions & 10 deletions torchtnt/framework/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,6 @@ def _train_epoch_impl(
logger.info("Started train epoch")
state._active_phase = ActivePhase.TRAIN

# Set all modules to train() mode
# access modules made available through AppStateMixin
tracked_modules = train_unit.tracked_modules()
prior_module_train_states = _set_module_training_mode(tracked_modules, True)

train_state = none_throws(state.train_state)

evaluate_every_n_steps = None
Expand Down Expand Up @@ -276,9 +271,4 @@ def _train_epoch_impl(
)
state._active_phase = ActivePhase.TRAIN

# Reset training mode for modules at the end of the epoch
# This ensures that side-effects made by the loop are reset before
# returning back to the user
_reset_module_training_mode(tracked_modules, prior_module_train_states)

logger.info("Ended train epoch")

0 comments on commit 5c9fe83

Please sign in to comment.