diff --git a/tests/framework/test_predict.py b/tests/framework/test_predict.py index 68e7059453..b33ba30128 100644 --- a/tests/framework/test_predict.py +++ b/tests/framework/test_predict.py @@ -8,14 +8,15 @@ # pyre-strict import unittest -from typing import Any, Iterator, Mapping, Tuple -from unittest.mock import MagicMock +from typing import Any, cast, Iterator, List, Mapping, Tuple +from unittest.mock import MagicMock, patch import torch from torch import nn from torchtnt.framework._test_utils import DummyPredictUnit, generate_random_dataloader from torchtnt.framework.callback import Callback +from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver from torchtnt.framework.predict import predict from torchtnt.framework.state import State @@ -223,6 +224,24 @@ def test_error_message(self) -> None: log.output, ) + def test_predict_ckpt_autograd_mode( + self, + ) -> None: + """ + Verify that the pytorch autograd mode used depends on having a checkpoint callback in predict. + """ + unit = DummyPredictUnit(2) + dataloader = generate_random_dataloader(10, 2, 2) + dcp_saver = DistributedCheckpointSaver(dirpath="dummy_dirpath") + + for callbacks, mock_autograd_mode in [ + ([], "torch.inference_mode"), + ([dcp_saver], "torch.no_grad"), + ]: + with patch(mock_autograd_mode) as mock_autograd_mode: + predict(unit, dataloader, callbacks=cast(List[Callback], callbacks)) + mock_autograd_mode.assert_called_once() + Batch = Tuple[torch.Tensor, torch.Tensor] diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index bf414ec72b..c4d362b3f6 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -20,6 +20,7 @@ _set_module_training_mode, ) from torchtnt.framework.callback import Callback +from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.framework.unit import TPredictData, TPredictUnit from torchtnt.framework.utils import get_timing_context @@ -80,7 +81,10 @@ def predict( call on_predict_end on unit first and then callbacks """ _log_api_usage("predict") - callback_handler = CallbackHandler(callbacks or []) + callbacks = callbacks or [] + callback_handler = CallbackHandler(callbacks) + checkpoint_cb_exists = any(isinstance(cb, BaseCheckpointer) for cb in callbacks) + state = State( entry_point=EntryPoint.PREDICT, predict_state=PhaseState( @@ -90,7 +94,13 @@ def predict( timer=timer, ) try: - _predict_impl(state, predict_unit, callback_handler) + # all_gather using inference_mode with gloo backend is not supported. Since this collective + # is necessary for checkpointing, we need to use torch.no_grad instead. + # TODO: remove this once all_gather is supported in inference_mode. + inference_ctx = torch.no_grad if checkpoint_cb_exists else torch.inference_mode + with inference_ctx(): + _predict_impl(state, predict_unit, callback_handler) + logger.info("Finished predict") if state.timer: logger.info(get_timer_summary(state.timer)) @@ -104,7 +114,6 @@ def predict( raise e -@torch.inference_mode() def _predict_impl( state: State, predict_unit: TPredictUnit,