diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index 37622e9eed..a71bea6ce4 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
-from typing import Any, Tuple
+from typing import Any, Literal, Tuple
 from unittest.mock import MagicMock, patch
 
 import torch
@@ -18,6 +18,7 @@
     COMPILE_AVAIL = True
     import torch._dynamo
 
+from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
 from torch.distributed import GradBucket
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
@@ -44,10 +45,12 @@
 from torchtnt.utils.test_utils import spawn_multi_process
 from torchtnt.utils.timer import Timer
 
+TParams = ParamSpec("TParams")
+
 
 class TestAutoUnit(unittest.TestCase):
-    # pyre-fixme[4]: Attribute must be annotated.
-    cuda_available = torch.cuda.is_available()
+    cuda_available: bool = torch.cuda.is_available()
+    distributed_available: bool = torch.distributed.is_available()
 
     def test_app_state_mixin(self) -> None:
         """
@@ -274,8 +277,7 @@ def forward(self, x):
         condition=cuda_available, reason="This test needs a GPU host to run."
     )
     @patch("torch.autocast")
-    # pyre-fixme[2]: Parameter must be annotated.
-    def test_eval_mixed_precision_bf16(self, mock_autocast) -> None:
+    def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
         """
         Test that the mixed precision autocast context is called during evaluate when precision = bf16
         """
@@ -295,10 +297,8 @@ def test_eval_mixed_precision_bf16(self, mock_autocast) -> None:
             device_type="cuda", dtype=torch.bfloat16, enabled=True
         )
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
     @unittest.skipUnless(
-        torch.distributed.is_available(), reason="Torch distributed is needed to run"
+        condition=distributed_available, reason="Torch distributed is needed to run"
     )
     @unittest.skipUnless(
         condition=cuda_available, reason="This test needs a GPU host to run."
     )
@@ -437,10 +437,8 @@ def test_configure_optimizers_and_lr_scheduler_called_once(self) -> None:
         )
         self.assertEqual(configure_optimizers_and_lr_scheduler_mock.call_count, 1)
 
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
     @unittest.skipUnless(
-        torch.distributed.is_available(), reason="Torch distributed is needed to run"
+        condition=distributed_available, reason="Torch distributed is needed to run"
    )
     def test_auto_unit_ddp(self) -> None:
         """
@@ -649,8 +647,7 @@ def test_auto_unit_timing_predict(self) -> None:
         condition=cuda_available, reason="This test needs a GPU host to run."
     )
     @patch("torch.autocast")
-    # pyre-fixme[2]: Parameter must be annotated.
-    def test_predict_mixed_precision_fp16(self, mock_autocast) -> None:
+    def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
         """
         Test that the mixed precision autocast context is called during predict when precision = fp16
         """
@@ -677,8 +674,7 @@ def test_predict_mixed_precision_fp16(self, mock_autocast) -> None:
         condition=cuda_available, reason="This test needs a GPU host to run."
     )
     @patch("torch.compile")
-    # pyre-fixme[2]: Parameter must be annotated.
-    def test_compile_predict(self, mock_dynamo) -> None:
+    def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
         """
         e2e torch compile on predict
         """
@@ -719,8 +715,7 @@ def test_auto_predict_unit_timing_predict(self) -> None:
         )
 
     @patch("torch.autograd.set_detect_anomaly")
-    # pyre-fixme[2]: Parameter must be annotated.
-    def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
+    def test_predict_detect_anomaly(self, mock_detect_anomaly: MagicMock) -> None:
         my_module = torch.nn.Linear(2, 2)
         auto_unit = AutoPredictUnit(module=my_module, detect_anomaly=True)
 
@@ -734,18 +729,52 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
         predict(auto_unit, predict_dl, max_steps_per_epoch=1)
         mock_detect_anomaly.assert_called()
 
+    def test_get_next_batch(self) -> None:
+        auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
+        data = iter([1, 2])
+        state = get_dummy_train_state()
+        self.assertFalse(auto_unit._prefetched)
+        self.assertIsNone(auto_unit._next_batch)
+        move_data_to_device_mock = patch.object(
+            auto_unit,
+            "move_data_to_device",
+            side_effect=lambda state, data, non_blocking: data,
+        )
+
+        with move_data_to_device_mock:
+            batch = auto_unit._get_next_batch(state, data)
+        self.assertEqual(batch, 1)
+        self.assertEqual(auto_unit._next_batch, 2)
+        self.assertTrue(auto_unit._prefetched)
+
+        with move_data_to_device_mock:
+            batch = auto_unit._get_next_batch(state, data)
+        self.assertEqual(batch, 2)
+        self.assertIsNone(auto_unit._next_batch)
+        self.assertTrue(auto_unit._prefetched)
+        self.assertTrue(auto_unit._is_last_train_batch)
+
+        with move_data_to_device_mock, self.assertRaises(StopIteration):
+            auto_unit._get_next_batch(state, data)
+        self.assertIsNone(auto_unit._next_batch)
+        self.assertFalse(auto_unit._prefetched)
+        self.assertFalse(auto_unit._is_last_train_batch)
+
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
 
 
 class DummyLRSchedulerAutoUnit(AutoUnit[Batch]):
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        step_lr_interval: Literal["step", "epoch"] = "epoch",
+    ) -> None:
+        super().__init__(module=module, step_lr_interval=step_lr_interval)
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
-    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+    def compute_loss(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
@@ -761,14 +790,13 @@ def configure_optimizers_and_lr_scheduler(
 
 
 class DummyComplexAutoUnit(AutoUnit[Batch]):
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, lr: float, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, lr: float, module: torch.nn.Module) -> None:
+        super().__init__(module=module)
         self.lr = lr
 
-    # pyre-fixme[3]: Return annotation cannot contain `Any`.
-    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+    def compute_loss(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
@@ -839,14 +867,12 @@ def on_train_step_end(
         data: Batch,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
     ) -> None:
         assert state.train_state
         if self.train_progress.num_steps_completed_in_epoch == 1:
             tc = unittest.TestCase()
-            # pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
-            recorded_timer_keys = state.timer.recorded_durations.keys()
+            recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
             for k in (
                 "TimingAutoUnit.on_train_start",
                 "TimingAutoUnit.on_train_epoch_start",
@@ -871,13 +897,11 @@ def on_eval_step_end(
         data: Batch,
         step: int,
         loss: torch.Tensor,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
     ) -> None:
         if self.eval_progress.num_steps_completed_in_epoch == 1:
             tc = unittest.TestCase()
-            # pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
-            recorded_timer_keys = state.timer.recorded_durations.keys()
+            recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
             for k in (
                 "TimingAutoUnit.on_eval_start",
                 "TimingAutoUnit.on_eval_epoch_start",
@@ -897,13 +921,11 @@ def on_predict_step_end(
         state: State,
         data: Batch,
         step: int,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
     ) -> None:
         if self.predict_progress.num_steps_completed_in_epoch == 1:
             tc = unittest.TestCase()
-            # pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
-            recorded_timer_keys = state.timer.recorded_durations.keys()
+            recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
             for k in (
                 "TimingAutoUnit.on_predict_start",
                 "TimingAutoUnit.on_predict_epoch_start",
@@ -944,13 +966,11 @@ def on_predict_step_end(
         state: State,
         data: TPredictData,
         step: int,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        outputs: Any,
+        outputs: torch.Tensor,
    ) -> None:
         if self.predict_progress.num_steps_completed_in_epoch == 1:
             tc = unittest.TestCase()
-            # pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
-            recorded_timer_keys = state.timer.recorded_durations.keys()
+            recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
             for k in (
                 "AutoPredictUnit.on_predict_start",
                 "AutoPredictUnit.on_predict_epoch_start",