add test for get_next_batch (#561)
Summary:
Pull Request resolved: #561

Introducing a test for AutoUnit's _get_next_batch and removing the pyre-ignore annotations in test_auto_unit (a short sketch of the prefetch behavior under test follows the commit metadata below).

Reviewed By: JKSenthil

Differential Revision: D49851294

fbshipit-source-id: 8e686c39d62c551736ccc26361dca7c266d74a78
galrotem authored and facebook-github-bot committed Oct 3, 2023
1 parent 8c5ef73 commit 587ea9a
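
The new test pins down a one-batch prefetch contract on AutoUnit._get_next_batch. Below is a minimal sketch of that contract as the assertions observe it; this is hypothetical, not the torchtnt source. The real method also takes a State and routes each batch through move_data_to_device, which the test stubs out.

from typing import Any, Iterator

class PrefetchSketch:
    def __init__(self) -> None:
        self._prefetched: bool = False
        self._next_batch: Any = None
        self._is_last_train_batch: bool = False

    def _get_next_batch(self, data: Iterator[Any]) -> Any:
        if not self._prefetched:
            # First call: fill the one-element lookahead buffer.
            self._next_batch = next(data)
            self._prefetched = True
        batch = self._next_batch
        if batch is None:
            # Iterator ran out on the previous call: reset state and stop.
            # (A real implementation would need a sentinel to tell an
            # exhausted iterator apart from a literal None batch.)
            self._prefetched = False
            self._is_last_train_batch = False
            raise StopIteration
        try:
            # Look one element ahead so the final batch can be flagged.
            self._next_batch = next(data)
        except StopIteration:
            self._next_batch = None
            self._is_last_train_batch = True
        return batch

On iter([1, 2]) this returns 1 with 2 buffered, then 2 while flagging the last batch, then raises StopIteration and resets, which is exactly the sequence the new test walks through.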
Showing 1 changed file with 65 additions and 45 deletions.
110 changes: 65 additions & 45 deletions tests/framework/test_auto_unit.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Tuple
from typing import Any, Literal, Tuple
from unittest.mock import MagicMock, patch

import torch
@@ -18,6 +18,7 @@
COMPILE_AVAIL = True
import torch._dynamo

from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
from torch.distributed import GradBucket
from torchtnt.framework._test_utils import (
DummyAutoUnit,
@@ -44,10 +45,12 @@
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.timer import Timer

TParams = ParamSpec("TParams")


class TestAutoUnit(unittest.TestCase):
# pyre-fixme[4]: Attribute must be annotated.
cuda_available = torch.cuda.is_available()
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def test_app_state_mixin(self) -> None:
"""
@@ -274,8 +277,7 @@ def forward(self, x):
condition=cuda_available, reason="This test needs a GPU host to run."
)
@patch("torch.autocast")
# pyre-fixme[2]: Parameter must be annotated.
def test_eval_mixed_precision_bf16(self, mock_autocast) -> None:
def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
"""
Test that the mixed precision autocast context is called during evaluate when precision = bf16
"""
@@ -295,10 +297,8 @@ def test_eval_mixed_precision_bf16(self, mock_autocast) -> None:
device_type="cuda", dtype=torch.bfloat16, enabled=True
)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
@@ -437,10 +437,8 @@ def test_configure_optimizers_and_lr_scheduler_called_once(self) -> None:
)
self.assertEqual(configure_optimizers_and_lr_scheduler_mock.call_count, 1)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
condition=distributed_available, reason="Torch distributed is needed to run"
)
def test_auto_unit_ddp(self) -> None:
"""
@@ -649,8 +647,7 @@ def test_auto_unit_timing_predict(self) -> None:
condition=cuda_available, reason="This test needs a GPU host to run."
)
@patch("torch.autocast")
# pyre-fixme[2]: Parameter must be annotated.
def test_predict_mixed_precision_fp16(self, mock_autocast) -> None:
def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
"""
Test that the mixed precision autocast context is called during predict when precision = fp16
"""
@@ -677,8 +674,7 @@ def test_predict_mixed_precision_fp16(self, mock_autocast) -> None:
condition=cuda_available, reason="This test needs a GPU host to run."
)
@patch("torch.compile")
# pyre-fixme[2]: Parameter must be annotated.
def test_compile_predict(self, mock_dynamo) -> None:
def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
"""
e2e torch compile on predict
"""
@@ -719,8 +715,7 @@ def test_auto_predict_unit_timing_predict(self) -> None:
)

@patch("torch.autograd.set_detect_anomaly")
# pyre-fixme[2]: Parameter must be annotated.
def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
def test_predict_detect_anomaly(self, mock_detect_anomaly: MagicMock) -> None:
my_module = torch.nn.Linear(2, 2)
auto_unit = AutoPredictUnit(module=my_module, detect_anomaly=True)

@@ -734,18 +729,52 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
predict(auto_unit, predict_dl, max_steps_per_epoch=1)
mock_detect_anomaly.assert_called()

def test_get_next_batch(self) -> None:
auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
data = iter([1, 2])
state = get_dummy_train_state()
self.assertFalse(auto_unit._prefetched)
self.assertIsNone(auto_unit._next_batch)
move_data_to_device_mock = patch.object(
auto_unit,
"move_data_to_device",
side_effect=lambda state, data, non_blocking: data,
)

with move_data_to_device_mock:
batch = auto_unit._get_next_batch(state, data)
self.assertEqual(batch, 1)
self.assertEqual(auto_unit._next_batch, 2)
self.assertTrue(auto_unit._prefetched)

with move_data_to_device_mock:
batch = auto_unit._get_next_batch(state, data)
self.assertEqual(batch, 2)
self.assertIsNone(auto_unit._next_batch)
self.assertTrue(auto_unit._prefetched)
self.assertTrue(auto_unit._is_last_train_batch)

with move_data_to_device_mock, self.assertRaises(StopIteration):
auto_unit._get_next_batch(state, data)
self.assertIsNone(auto_unit._next_batch)
self.assertFalse(auto_unit._prefetched)
self.assertFalse(auto_unit._is_last_train_batch)


Batch = Tuple[torch.Tensor, torch.Tensor]


class DummyLRSchedulerAutoUnit(AutoUnit[Batch]):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
module: torch.nn.Module,
step_lr_interval: Literal["step", "epoch"] = "epoch",
) -> None:
super().__init__(module=module, step_lr_interval=step_lr_interval)

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
def compute_loss(
self, state: State, data: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
@@ -761,14 +790,13 @@ def configure_optimizers_and_lr_scheduler(


class DummyComplexAutoUnit(AutoUnit[Batch]):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, lr: float, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, lr: float, module: torch.nn.Module) -> None:
super().__init__(module=module)
self.lr = lr

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
def compute_loss(
self, state: State, data: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
@@ -839,14 +867,12 @@ def on_train_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
outputs: torch.Tensor,
) -> None:
assert state.train_state
if self.train_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
# pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
recorded_timer_keys = state.timer.recorded_durations.keys()
recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
for k in (
"TimingAutoUnit.on_train_start",
"TimingAutoUnit.on_train_epoch_start",
@@ -871,13 +897,11 @@ def on_eval_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
outputs: torch.Tensor,
) -> None:
if self.eval_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
# pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
recorded_timer_keys = state.timer.recorded_durations.keys()
recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
for k in (
"TimingAutoUnit.on_eval_start",
"TimingAutoUnit.on_eval_epoch_start",
@@ -897,13 +921,11 @@ def on_predict_step_end(
state: State,
data: Batch,
step: int,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
outputs: torch.Tensor,
) -> None:
if self.predict_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
# pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
recorded_timer_keys = state.timer.recorded_durations.keys()
recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
for k in (
"TimingAutoUnit.on_predict_start",
"TimingAutoUnit.on_predict_epoch_start",
@@ -944,13 +966,11 @@ def on_predict_step_end(
state: State,
data: TPredictData,
step: int,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
outputs: torch.Tensor,
) -> None:
if self.predict_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
# pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
recorded_timer_keys = state.timer.recorded_durations.keys()
recorded_timer_keys = none_throws(state.timer).recorded_durations.keys()
for k in (
"AutoPredictUnit.on_predict_start",
"AutoPredictUnit.on_predict_epoch_start",
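A note on the pyre-fixme[16] cleanups above: none_throws from pyre_extensions narrows Optional[T] to T, raising if the value really is None, so chained access like none_throws(state.timer).recorded_durations type-checks without suppressions. A minimal, self-contained illustration of the pattern (example names are illustrative, not from this diff):

from typing import Dict, Optional

from pyre_extensions import none_throws

def lookup_duration(durations: Dict[str, float], key: str) -> Optional[float]:
    return durations.get(key)

durations = {"train_step": 0.25}
# none_throws turns Optional[float] into float for the type checker and
# raises at runtime if the value is None, replacing a pyre-fixme[16].
step_time: float = none_throws(lookup_duration(durations, "train_step"))
print(step_time)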

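The skipUnless changes follow the same theme: pyre could not infer the type of a call expression such as torch.distributed.is_available() passed directly to the decorator factory, so the diff hoists the condition into an annotated class attribute. A small sketch of that pattern (the test class and method here are illustrative only):

import unittest

import torch

class ExampleTest(unittest.TestCase):
    # Evaluated once at class-definition time, with an explicit annotation
    # the type checker can follow into the decorator argument below.
    distributed_available: bool = torch.distributed.is_available()

    @unittest.skipUnless(
        condition=distributed_available, reason="Torch distributed is needed to run"
    )
    def test_needs_distributed(self) -> None:
        self.assertTrue(torch.distributed.is_available())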

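Finally, the diff imports ParameterSpecification as ParamSpec and declares TParams. Its use is not visible in the hunks shown, but the usual role of a ParamSpec is to type decorators that forward *args and **kwargs, the situation the removed pyre-fixme[2]/[3] comments were papering over. A generic sketch using typing.ParamSpec, the standard-library equivalent (Python 3.10+):

from typing import Callable, ParamSpec, TypeVar

TParams = ParamSpec("TParams")
TReturn = TypeVar("TReturn")

def log_calls(fn: Callable[TParams, TReturn]) -> Callable[TParams, TReturn]:
    # The wrapper preserves fn's exact parameter types instead of an
    # unannotated *args/**kwargs that would otherwise need suppressions.
    def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> TReturn:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return wrapper

@log_calls
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor

print(scale(3.0))  # prints "calling scale", then 6.0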