subclass PredictUnit in AutoUnit (#545)
Summary:
Pull Request resolved: #545

# Context
Currently, if users want to train, eval, and predict all in one script, they must transfer their modules from AutoUnit to AutoPredictUnit, which is unintuitive and error-prone.

# This diff
We subclass PredictUnit in AutoUnit, which allows users to use the existing `predict_step`, or override it, when calling `predict` (see the sketch below).
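
As a hedged illustration (not part of this diff), here is a minimal sketch of the workflow this enables, assuming the `train`/`evaluate`/`predict` entry points and the `compute_loss`/`configure_optimizers_and_lr_scheduler` contract described in the AutoUnit docstring below; `MyAutoUnit` is a hypothetical name:

```python
# Minimal sketch: one AutoUnit subclass reused for train, evaluate, and
# predict. MyAutoUnit is hypothetical; compute_loss and
# configure_optimizers_and_lr_scheduler are the two methods the AutoUnit
# docstring says subclasses must implement.
from typing import Any, Optional, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtnt.framework import evaluate, predict, train
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

Batch = Tuple[torch.Tensor, torch.Tensor]  # (inputs, targets)


class MyAutoUnit(AutoUnit[Batch]):
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.mse_loss(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None


inputs, targets = torch.randn(10, 2), torch.randn(10, 2)
dataloader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
unit = MyAutoUnit(module=torch.nn.Linear(2, 2))

train(unit, dataloader, max_epochs=1)
evaluate(unit, dataloader)
# predict batches carry inputs only, mirroring the generator in the tests below
predict(unit, (batch[0] for batch in dataloader))
```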

Reviewed By: galrotem

Differential Revision: D49441091

fbshipit-source-id: dbf744ae2841c41b8a2ddaa1558147c09fc75889
JKSenthil authored and facebook-github-bot committed Sep 20, 2023
1 parent df2768f commit 0b926c3
Showing 2 changed files with 119 additions and 4 deletions.
70 changes: 70 additions & 0 deletions tests/framework/test_auto_unit.py
@@ -167,6 +167,29 @@ def test_mixed_precision_invalid_str(self) -> None:
precision="foo",
)

def test_predict_step(self) -> None:
"""
Test predict step functionality
"""
my_module = torch.nn.Linear(2, 2)
auto_unit = DummyAutoUnit(
module=my_module,
)

input_dim = 2
dataset_len = 10
batch_size = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
dataloader_iter = iter(dataloader)
pred_dataloader = (x[0] for x in dataloader_iter) # only need data, not target

with patch(
"torchtnt.framework._test_utils.DummyAutoUnit.on_predict_step_end"
) as mock_predict_step_end:
predict(auto_unit, pred_dataloader, max_steps_per_epoch=1)
mock_predict_step_end.assert_called_once()

def test_stochastic_weight_averaging_basic(self) -> None:
"""
Basic stochastic weight averaging tests
@@ -600,6 +623,28 @@ def test_auto_unit_timing_eval(self) -> None:
timer=Timer(),
)

def test_auto_unit_timing_predict(self) -> None:
"""
Test auto timing in AutoUnit for predict
"""
input_dim = 2
dataset_len = 10
batch_size = 2
max_steps_per_epoch = 1

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
dataloader_iter = iter(dataloader)
pred_dataloader = (x[0] for x in dataloader_iter) # only need data, not targets

my_module = torch.nn.Linear(2, 2)

predict(
TimingAutoUnit(module=my_module),
pred_dataloader,
max_steps_per_epoch=max_steps_per_epoch,
timer=Timer(),
)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
@@ -847,6 +892,31 @@ def on_eval_step_end(
# eval_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's eval_step
tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys)

def on_predict_step_end(
self,
state: State,
data: Batch,
step: int,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
) -> None:
if self.predict_progress.num_steps_completed_in_epoch == 1:
tc = unittest.TestCase()
# pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
recorded_timer_keys = state.timer.recorded_durations.keys()
for k in (
"TimingAutoUnit.on_predict_start",
"TimingAutoUnit.on_predict_epoch_start",
"predict.iter(dataloader)",
"predict.next(data_iter)",
"TimingAutoUnit.move_data_to_device",
"TimingAutoUnit.on_predict_step_end",
):
tc.assertIn(k, recorded_timer_keys)

# predict_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's predict_step
tc.assertNotIn("TimingAutoUnit.predict_step", recorded_timer_keys)


class TimingAutoPredictUnit(AutoPredictUnit[Batch]):
def __init__(self, module: torch.nn.Module) -> None:
53 changes: 49 additions & 4 deletions torchtnt/framework/auto_unit.py
@@ -263,13 +263,14 @@ def _prefetch_next_batch(
class AutoUnit(
TrainUnit[TData],
EvalUnit[TData],
PredictUnit[TData],
metaclass=_ConfigureOptimizersCaller,
):
"""
The AutoUnit is a convenience for users who are training with stochastic gradient descent and would like to have model optimization
and data parallel replication handled for them.
The AutoUnit subclasses :class:`~torchtnt.framework.unit.TrainUnit` and :class:`~torchtnt.framework.unit.EvalUnit`,
and implements the ``train_step`` and ``eval_step`` methods for the user.
The AutoUnit subclasses :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, and
:class:`~torchtnt.framework.unit.PredictUnit` and implements the ``train_step``, ``eval_step``, and ``predict_step`` methods for the user.
For the ``train_step`` it runs:
@@ -279,15 +280,20 @@ class AutoUnit(
For the ``eval_step`` it only runs forward and loss computation.
For the ``predict_step`` it only runs forward computation.
To benefit from the AutoUnit, the user must subclass it and implement the ``compute_loss`` and ``configure_optimizers_and_lr_scheduler`` methods.
Additionally, the AutoUnit offers these optional hooks:
- ``on_train_step_end``
- ``on_eval_step_end``
- ``on_predict_step_end``
Then use with the :py:func:`~torchtnt.framework.train`, :py:func:`~torchtnt.framework.evaluate`, or :py:func:`~torchtnt.framework.fit` entry point as normal.
Then use with the :py:func:`~torchtnt.framework.train`, :py:func:`~torchtnt.framework.evaluate`, :py:func:`~torchtnt.framework.fit`, or
:py:func:`~torchtnt.framework.predict` entry point as normal.
For more advanced customization, directly use the :class:`~torchtnt.framework.unit.TrainUnit` and :class:`~torchtnt.framework.unit.EvalUnit` interfaces.
For more advanced customization, directly use the :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`,
and :class:`~torchtnt.framework.unit.PredictUnit` interfaces.
Args:
module: module to be used during training/evaluation.
@@ -736,6 +742,45 @@ def on_eval_step_end(
"""
pass

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def predict_step(self, state: State, data: TData) -> Any:
with get_timing_context(
state, f"{self.__class__.__name__}.move_data_to_device"
):
data = self.move_data_to_device(state, data, non_blocking=False)

with self.maybe_autocast_precision:
with get_timing_context(state, f"{self.__class__.__name__}.forward"):
outputs = self.module(data)

step = self.predict_progress.num_steps_completed
# users can override this; by default it is a no-op
with get_timing_context(
state, f"{self.__class__.__name__}.on_predict_step_end"
):
self.on_predict_step_end(state, data, step, outputs)
return outputs

def on_predict_step_end(
self,
state: State,
data: TData,
step: int,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
) -> None:
"""
This will be called at the end of every ``predict_step`` before returning. The user can implement this method with code to update and log their metrics,
or do anything else.
Args:
state: a State object which is passed from the ``predict_step``
data: a batch of data which is passed from the ``predict_step``
step: how many ``predict_step``s have been completed
outputs: the outputs of the model forward pass
"""
pass


def _validate_torch_compile_available() -> None:
if not is_torch_version_ge_1_13_1():
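A second hedged sketch, this time of the optional `on_predict_step_end` hook documented above; `PredictionCollector` and its list-based output collection are illustrative stand-ins for real metric or logging code:

```python
# Sketch of overriding on_predict_step_end to collect outputs; the
# PredictionCollector name and list-based collection are hypothetical.
from typing import Any, List, Optional, Tuple

import torch
from torchtnt.framework import predict
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State


class PredictionCollector(AutoUnit[torch.Tensor]):
    def __init__(self, *, module: torch.nn.Module) -> None:
        super().__init__(module=module)
        self.predictions: List[torch.Tensor] = []

    def compute_loss(
        self, state: State, data: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        raise NotImplementedError  # predict-only unit; train/eval are unused here

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None

    def on_predict_step_end(
        self, state: State, data: torch.Tensor, step: int, outputs: Any
    ) -> None:
        # called by AutoUnit's built-in predict_step after the forward pass
        self.predictions.append(outputs.detach().cpu())


unit = PredictionCollector(module=torch.nn.Linear(2, 2))
predict(unit, [torch.randn(4, 2) for _ in range(3)])
print(len(unit.predictions))  # 3, one entry per predict step
```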
