Skip to content

Commit

Permalink
subclass PredictUnit in AutoUnit (pytorch#545)
Browse files Browse the repository at this point in the history
Summary:

# Context
Currently, if users want to train, eval, and predict all in one script, they must transfer their modules from AutoUnit to AutoPredictUnit, which is unintuitive and error-prone.

# This diff
We subclass PredictUnit in AutoUnit, which allows users to use the existing predict_step, or override it, when calling `predict`.

Differential Revision: D49441091
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Sep 20, 2023
1 parent 9d6cd91 commit f77d11b
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
70 changes: 70 additions & 0 deletions tests/framework/test_auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,29 @@ def test_mixed_precision_invalid_str(self) -> None:
precision="foo",
)

def test_predict_step(self) -> None:
    """
    Check that running ``predict`` on an AutoUnit for a single step invokes
    ``on_predict_step_end`` exactly once.
    """
    input_dim = 2
    dataset_len = 10
    batch_size = 2

    # The unit is built before the dataloader, as in the original test.
    auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))

    # Prediction consumes inputs only, so keep the first element of each batch
    # and drop the target.
    batches = iter(generate_random_dataloader(dataset_len, input_dim, batch_size))
    pred_dataloader = (batch[0] for batch in batches)

    hook_path = "torchtnt.framework._test_utils.DummyAutoUnit.on_predict_step_end"
    with patch(hook_path) as mock_predict_step_end:
        predict(auto_unit, pred_dataloader, max_steps_per_epoch=1)
        mock_predict_step_end.assert_called_once()

def test_stochastic_weight_averaging_basic(self) -> None:
"""
Basic stochastic weight averaging tests
Expand Down Expand Up @@ -600,6 +623,28 @@ def test_auto_unit_timing_eval(self) -> None:
timer=Timer(),
)

def test_auto_unit_timing_predict(self) -> None:
    """
    Run ``predict`` for one step on a ``TimingAutoUnit`` with a ``Timer``
    attached. The timing assertions themselves are expected to live in the
    unit's ``on_predict_step_end`` hook rather than in this method.
    """
    input_dim, dataset_len, batch_size = 2, 10, 2
    max_steps_per_epoch = 1

    # Only the inputs are needed for prediction, not the targets.
    batches = iter(generate_random_dataloader(dataset_len, input_dim, batch_size))
    pred_dataloader = (batch[0] for batch in batches)

    timing_unit = TimingAutoUnit(module=torch.nn.Linear(2, 2))

    predict(
        timing_unit,
        pred_dataloader,
        max_steps_per_epoch=max_steps_per_epoch,
        timer=Timer(),
    )

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
Expand Down Expand Up @@ -847,6 +892,31 @@ def on_eval_step_end(
# eval_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's eval_step
tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys)

def on_predict_step_end(
    self,
    state: State,
    data: Batch,
    step: int,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    outputs: Any,
) -> None:
    """
    After the first completed predict step of the epoch, assert which timer
    keys have been recorded on ``state.timer``.

    Args:
        state: the State whose timer is inspected
        data: the batch for this step (unused here)
        step: the step count passed in by the caller (unused here)
        outputs: the forward-pass outputs (unused here)
    """
    # Only check once, right after the first step of the epoch has completed.
    if self.predict_progress.num_steps_completed_in_epoch != 1:
        return

    tc = unittest.TestCase()
    # pyre-fixme[16]: Optional type has no attribute `recorded_durations`.
    recorded_timer_keys = state.timer.recorded_durations.keys()

    expected_keys = (
        "TimingAutoUnit.on_predict_start",
        "TimingAutoUnit.on_predict_epoch_start",
        "predict.iter(dataloader)",
        "predict.next(data_iter)",
        "TimingAutoUnit.move_data_to_device",
        "TimingAutoUnit.on_predict_step_end",
    )
    for key in expected_keys:
        tc.assertIn(key, recorded_timer_keys)

    # predict_step should not be in the timer's recorded_durations because it
    # overlaps with other timings in the AutoUnit's predict_step
    tc.assertNotIn("TimingAutoUnit.predict_step", recorded_timer_keys)


class TimingAutoPredictUnit(AutoPredictUnit[Batch]):
def __init__(self, module: torch.nn.Module) -> None:
Expand Down
40 changes: 40 additions & 0 deletions torchtnt/framework/auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def _prefetch_next_batch(
class AutoUnit(
TrainUnit[TData],
EvalUnit[TData],
PredictUnit[TData],
metaclass=_ConfigureOptimizersCaller,
):
"""
Expand Down Expand Up @@ -736,6 +737,45 @@ def on_eval_step_end(
"""
pass

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def predict_step(self, state: State, data: TData) -> Any:
    """
    Run one prediction step: move the batch via ``move_data_to_device``, call
    ``self.module`` on it inside ``self.maybe_autocast_precision``, then invoke
    the ``on_predict_step_end`` hook before returning.

    Args:
        state: a State object, also used to look up the timing contexts
        data: a batch of data to run the module on

    Returns:
        whatever ``self.module`` returns for the moved batch
    """
    with get_timing_context(
        state, f"{self.__class__.__name__}.move_data_to_device"
    ):
        batch = self.move_data_to_device(state, data, non_blocking=False)

    # Autocast is entered first, then the forward timing context, matching the
    # nesting order of two separate ``with`` statements.
    with self.maybe_autocast_precision, get_timing_context(
        state, f"{self.__class__.__name__}.forward"
    ):
        predictions = self.module(batch)

    # Read the step count before entering the hook's timing context.
    completed_steps = self.predict_progress.num_steps_completed
    # The hook is a no-op unless the user overrides it.
    with get_timing_context(
        state, f"{self.__class__.__name__}.on_predict_step_end"
    ):
        self.on_predict_step_end(state, batch, completed_steps, predictions)

    return predictions

def on_predict_step_end(
    self,
    state: State,
    data: TData,
    step: int,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    outputs: Any,
) -> None:
    """
    Hook invoked at the end of every ``predict_step``, just before it returns.

    The default implementation does nothing. Override it to update and log
    metrics, or to run any other per-step logic.

    Args:
        state: the State object forwarded from ``predict_step``
        data: the batch of data forwarded from ``predict_step``
        step: the number of ``predict_step`` calls completed so far
        outputs: the outputs of the model forward pass
    """


def _validate_torch_compile_available() -> None:
if not is_torch_version_ge_1_13_1():
Expand Down

0 comments on commit f77d11b

Please sign in to comment.