Use no_grad instead of inference_mode for predict with checkpointing (#912)

Summary: Pull Request resolved: #912

Reviewed By: mvsfb, JKSenthil

Differential Revision: D63491475
diego-urgell authored and facebook-github-bot committed Oct 4, 2024
1 parent 1f06115 commit b136b4e
Showing 2 changed files with 33 additions and 5 deletions.
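
For context (not part of the commit), a minimal sketch of the autograd-context selection the diff below introduces: predict falls back to torch.no_grad when a checkpointing callback is registered, because all_gather under torch.inference_mode is not supported with the gloo backend. The names autograd_context and has_checkpoint_cb are illustrative; in the actual change the check is any(isinstance(cb, BaseCheckpointer) for cb in callbacks).

import torch

def autograd_context(has_checkpoint_cb: bool):
    # Checkpointing relies on collectives (e.g. all_gather) that are not
    # supported under inference_mode with the gloo backend, so fall back to
    # no_grad; otherwise keep inference_mode for its lower overhead.
    return torch.no_grad() if has_checkpoint_cb else torch.inference_mode()

with autograd_context(has_checkpoint_cb=True):
    pass  # prediction loop runs here; gradients are disabled either way
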
tests/framework/test_predict.py (23 changes: 21 additions & 2 deletions)
@@ -8,14 +8,15 @@
 # pyre-strict

 import unittest
-from typing import Any, Iterator, Mapping, Tuple
-from unittest.mock import MagicMock
+from typing import Any, cast, Iterator, List, Mapping, Tuple
+from unittest.mock import MagicMock, patch

 import torch
 from torch import nn

 from torchtnt.framework._test_utils import DummyPredictUnit, generate_random_dataloader
 from torchtnt.framework.callback import Callback
+from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

 from torchtnt.framework.predict import predict
 from torchtnt.framework.state import State
@@ -223,6 +224,24 @@ def test_error_message(self) -> None:
                 log.output,
             )

+    def test_predict_ckpt_autograd_mode(
+        self,
+    ) -> None:
+        """
+        Verify that the pytorch autograd mode used depends on having a checkpoint callback in predict.
+        """
+        unit = DummyPredictUnit(2)
+        dataloader = generate_random_dataloader(10, 2, 2)
+        dcp_saver = DistributedCheckpointSaver(dirpath="dummy_dirpath")
+
+        for callbacks, mock_autograd_mode in [
+            ([], "torch.inference_mode"),
+            ([dcp_saver], "torch.no_grad"),
+        ]:
+            with patch(mock_autograd_mode) as mock_autograd_mode:
+                predict(unit, dataloader, callbacks=cast(List[Callback], callbacks))
+                mock_autograd_mode.assert_called_once()
+

 Batch = Tuple[torch.Tensor, torch.Tensor]

torchtnt/framework/predict.py (15 changes: 12 additions & 3 deletions)
@@ -20,6 +20,7 @@
     _set_module_training_mode,
 )
 from torchtnt.framework.callback import Callback
+from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
 from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
 from torchtnt.framework.unit import TPredictData, TPredictUnit
 from torchtnt.framework.utils import get_timing_context
@@ -80,7 +81,10 @@ def predict(
         call on_predict_end on unit first and then callbacks
     """
     _log_api_usage("predict")
-    callback_handler = CallbackHandler(callbacks or [])
+    callbacks = callbacks or []
+    callback_handler = CallbackHandler(callbacks)
+    checkpoint_cb_exists = any(isinstance(cb, BaseCheckpointer) for cb in callbacks)
+
     state = State(
         entry_point=EntryPoint.PREDICT,
         predict_state=PhaseState(
@@ -90,7 +94,13 @@
         timer=timer,
     )
     try:
-        _predict_impl(state, predict_unit, callback_handler)
+        # all_gather using inference_mode with gloo backend is not supported. Since this collective
+        # is necessary for checkpointing, we need to use torch.no_grad instead.
+        # TODO: remove this once all_gather is supported in inference_mode.
+        inference_ctx = torch.no_grad if checkpoint_cb_exists else torch.inference_mode
+        with inference_ctx():
+            _predict_impl(state, predict_unit, callback_handler)
+
         logger.info("Finished predict")
         if state.timer:
             logger.info(get_timer_summary(state.timer))
@@ -104,7 +114,6 @@
         raise e


-@torch.inference_mode()
 def _predict_impl(
     state: State,
     predict_unit: TPredictUnit,
