Skip to content

Commit

Permalink
Breaking Change - Move progress to Unit (#459)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #459

1. Progress is no longer tracked on the State but on the Unit
2. Removed the `get_current_progress` utility

Differential Revision: D47358509

fbshipit-source-id: 3fa1836286d6cd89e8c73cea28dd789322ea75cc
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Jul 17, 2023
1 parent 3cbec3c commit a80fc50
Show file tree
Hide file tree
Showing 33 changed files with 245 additions and 307 deletions.
7 changes: 5 additions & 2 deletions examples/auto_unit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.utils.data.dataset import Dataset, TensorDataset
from torcheval.metrics import BinaryAccuracy
from torchtnt.framework import AutoUnit, fit, init_fit_state, State
from torchtnt.framework.utils import get_current_progress
from torchtnt.framework.state import EntryPoint
from torchtnt.utils import init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger

Expand Down Expand Up @@ -120,7 +120,10 @@ def on_eval_step_end(
self.eval_accuracy.update(outputs, targets)

def on_eval_end(self, state: State) -> None:
step = get_current_progress(state).num_steps_completed
if state.entry_point == EntryPoint.FIT:
step = self.train_progress.num_steps_completed
else:
step = self.eval_progress.num_steps_completed
accuracy = self.eval_accuracy.compute()
self.tb_logger.log("eval_accuracy", accuracy, step)
self.eval_accuracy.reset()
Expand Down
6 changes: 0 additions & 6 deletions examples/mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,6 @@ def on_train_epoch_end(self, state: State) -> None:
def on_eval_step_end(
self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any
) -> None:
# step_count = state.eval_state.progress.num_steps_completed
# data = copy_data_to_device(data, self.device)
# inputs, targets = data

# outputs = self.module(inputs)
# loss = torch.nn.functional.nll_loss(outputs, targets)
if step % self.log_every_n_steps == 0:
self.tb_logger.log("evaluation loss", loss, step)

Expand Down
4 changes: 2 additions & 2 deletions examples/torchdata_train_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ def train_step(self, state: State, data: Batch) -> None:

# update metrics & logs
self.train_accuracy.update(outputs, targets)
step_count = state.train_state.progress.num_steps_completed
step_count = self.train_progress.num_steps_completed
if (step_count + 1) % self.log_every_n_steps == 0:
accuracy = self.train_accuracy.compute()
self.tb_logger.log("loss", loss, step_count)
self.tb_logger.log("accuracy", accuracy, step_count)

def on_train_epoch_end(self, state: State) -> None:
# compute and log the metrics at the end of epoch
step_count = state.train_state.progress.num_steps_completed
step_count = self.train_progress.num_steps_completed
accuracy = self.train_accuracy.compute()
self.tb_logger.log("accuracy_epoch", accuracy, step_count)

Expand Down
5 changes: 2 additions & 3 deletions examples/torchrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

from torchtnt.framework import EvalUnit, fit, init_fit_state, State, TrainUnit
from torchtnt.framework.callbacks import TQDMProgressBar
from torchtnt.framework.utils import get_current_progress
from torchtnt.utils import (
get_process_group_backend_from_device,
init_from_env,
Expand Down Expand Up @@ -202,7 +201,7 @@ def __init__(
self.log_every_n_steps = log_every_n_steps

def train_step(self, state: State, data: Iterator[Batch]) -> None:
step = get_current_progress(state).num_steps_completed
step = self.train_progress.num_steps_completed
loss, logits, labels = self.pipeline.progress(data)
preds = torch.sigmoid(logits)
self.train_auroc.update(preds, labels)
Expand All @@ -217,7 +216,7 @@ def on_train_epoch_end(self, state: State) -> None:
self.train_auroc.reset()

def eval_step(self, state: State, data: Iterator[Batch]) -> None:
step = get_current_progress(state).num_steps_completed
step = self.eval_progress.num_steps_completed
loss, _, _ = self.pipeline.progress(data)
if step % self.log_every_n_steps == 0:
self.tb_logger.log("evaluation_loss", loss, step)
Expand Down
4 changes: 2 additions & 2 deletions examples/train_unit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ def train_step(self, state: State, data: Batch) -> None:

# update metrics & logs
self.train_accuracy.update(outputs, targets)
step_count = state.train_state.progress.num_steps_completed
step_count = self.train_progress.num_steps_completed
if (step_count + 1) % self.log_every_n_steps == 0:
acc = self.train_accuracy.compute()
self.tb_logger.log("loss", loss, step_count)
self.tb_logger.log("accuracy", acc, step_count)

def on_train_epoch_end(self, state: State) -> None:
# compute and log the metric at the end of the epoch
step_count = state.train_state.progress.num_steps_completed
step_count = self.train_progress.num_steps_completed
acc = self.train_accuracy.compute()
self.tb_logger.log("accuracy_epoch", acc, step_count)

Expand Down
6 changes: 3 additions & 3 deletions tests/framework/callbacks/test_csv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_csv_writer(self) -> None:
dataset_len = 10
batch_size = 2

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_predict_state(dataloader=dataloader)

Expand All @@ -73,7 +73,7 @@ def test_csv_writer_single_row(self) -> None:
dataset_len = 10
batch_size = 2

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_predict_state(dataloader=dataloader)

Expand All @@ -96,7 +96,7 @@ def test_csv_writer_with_no_output_rows_def(self) -> None:
dataset_len = 10
batch_size = 2

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_predict_state(dataloader=dataloader)

Expand Down
16 changes: 8 additions & 8 deletions tests/framework/callbacks/test_garbage_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_garbage_collector_call_count_train(self) -> None:
max_epochs = 2
expected_num_total_steps = dataset_len / batch_size * max_epochs

my_unit = MagicMock(spec=DummyTrainUnit)
my_unit = DummyTrainUnit(2)
gc_callback_mock = MagicMock(spec=GarbageCollector)

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
Expand All @@ -57,7 +57,7 @@ def test_garbage_collector_enabled_train(self) -> None:
batch_size = 2
max_epochs = 2

my_unit = MagicMock(spec=DummyTrainUnit)
my_unit = DummyTrainUnit(2)
gc_callback = GarbageCollector(2)

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
Expand All @@ -76,7 +76,7 @@ def test_garbage_collector_call_count_evaluate(self) -> None:
batch_size = 2
expected_num_total_steps = dataset_len / batch_size

my_unit = MagicMock(spec=DummyEvalUnit)
my_unit = DummyEvalUnit(2)
gc_callback_mock = MagicMock(spec=GarbageCollector)

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
Expand All @@ -97,7 +97,7 @@ def test_garbage_collector_enabled_evaluate(self) -> None:
dataset_len = 10
batch_size = 2

my_unit = MagicMock(spec=DummyEvalUnit)
my_unit = DummyEvalUnit(2)
gc_callback = GarbageCollector(2)

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
Expand All @@ -116,7 +116,7 @@ def test_garbage_collector_call_count_predict(self) -> None:
batch_size = 2
expected_num_total_steps = dataset_len / batch_size

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
gc_callback_mock = MagicMock(spec=GarbageCollector)

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
Expand All @@ -137,7 +137,7 @@ def test_garbage_collector_enabled_predict(self) -> None:
dataset_len = 10
batch_size = 2

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
gc_callback = GarbageCollector(2)

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
Expand All @@ -163,7 +163,7 @@ def test_garbage_collector_call_count_fit(self) -> None:
)
gc_step_interval = 4

my_unit = MagicMock(spec=DummyFitUnit)
my_unit = DummyFitUnit(2)
gc_callback = GarbageCollector(gc_step_interval)

train_dataloader = generate_random_dataloader(
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_garbage_collector_enabled_fit(self) -> None:
max_epochs = 2
evaluate_every_n_epochs = 1

my_unit = MagicMock(spec=DummyFitUnit)
my_unit = DummyFitUnit(2)
gc_callback = GarbageCollector(2)

train_dataloader = generate_random_dataloader(
Expand Down
4 changes: 2 additions & 2 deletions tests/framework/callbacks/test_pytorch_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_profiler_evaluate(self) -> None:
batch_size = 2
expected_num_total_steps = dataset_len / batch_size

my_unit = MagicMock(spec=DummyEvalUnit)
my_unit = DummyEvalUnit(2)
profiler_mock = MagicMock(spec=torch.profiler.profile)

profiler = PyTorchProfiler(profiler=profiler_mock)
Expand All @@ -75,7 +75,7 @@ def test_profiler_predict(self) -> None:
batch_size = 2
expected_num_total_steps = dataset_len / batch_size

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
profiler_mock = MagicMock(spec=torch.profiler.profile)

profiler = PyTorchProfiler(profiler=profiler_mock)
Expand Down
6 changes: 2 additions & 4 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,10 @@ def test_save_restore(self) -> None:
)
train(state, my_unit, callbacks=[snapshot_cb])

end_num_steps_completed = state.train_state.progress.num_steps_completed
end_num_steps_completed = my_unit.train_progress.num_steps_completed
self.assertGreater(len(expected_paths), 0)
snapshot_cb.restore(expected_paths[0], state, my_unit)
restored_num_steps_completed = (
state.train_state.progress.num_steps_completed
)
restored_num_steps_completed = my_unit.train_progress.num_steps_completed
# A snapshot is saved every n steps
# so the first snapshot's progress will be equal to save_every_n_train_steps
self.assertNotEqual(restored_num_steps_completed, end_num_steps_completed)
Expand Down
14 changes: 6 additions & 8 deletions tests/framework/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# LICENSE file in the root directory of this source tree.

import unittest
from unittest.mock import MagicMock

from torchtnt.framework._test_utils import (
DummyEvalUnit,
Expand Down Expand Up @@ -39,7 +38,7 @@ def test_progress_bar_train(self) -> None:
),
)

my_unit = MagicMock(spec=DummyTrainUnit)
my_unit = DummyTrainUnit(2)
progress_bar = TQDMProgressBar()
progress_bar.on_train_epoch_start(state, my_unit)
self.assertEqual(progress_bar._train_progress_bar.total, expected_total)
Expand All @@ -56,7 +55,7 @@ def test_progress_bar_train_integration(self) -> None:
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_train_state(dataloader=dataloader, max_epochs=max_epochs)

my_unit = MagicMock(spec=DummyTrainUnit)
my_unit = DummyTrainUnit(2)
progress_bar = TQDMProgressBar()
train(state, my_unit, callbacks=[progress_bar])

Expand All @@ -79,7 +78,7 @@ def test_progress_bar_evaluate(self) -> None:
),
)

my_unit = MagicMock(spec=DummyEvalUnit)
my_unit = DummyEvalUnit(2)
progress_bar = TQDMProgressBar()
progress_bar.on_eval_epoch_start(state, my_unit)
self.assertEqual(progress_bar._eval_progress_bar.total, expected_total)
Expand All @@ -103,7 +102,7 @@ def test_progress_bar_predict(self) -> None:
),
)

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
progress_bar = TQDMProgressBar()
progress_bar.on_predict_epoch_start(state, my_unit)
self.assertEqual(progress_bar._predict_progress_bar.total, expected_total)
Expand All @@ -126,9 +125,8 @@ def test_progress_bar_mid_progress(self) -> None:
max_epochs=max_epochs,
),
)
state.predict_state.progress._num_steps_completed = 2

my_unit = MagicMock(spec=DummyPredictUnit)
my_unit = DummyPredictUnit(2)
my_unit.predict_progress._num_steps_completed = 2
progress_bar = TQDMProgressBar()
progress_bar.on_predict_epoch_start(state, my_unit)
self.assertEqual(progress_bar._predict_progress_bar.total, expected_total)
Expand Down
6 changes: 3 additions & 3 deletions tests/framework/test_auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def _test_ddp_no_sync() -> None:
auto_unit.train_step(state=state, data=dummy_iterator)
no_sync_mock.assert_called_once()

state.train_state.progress.increment_step()
auto_unit.train_progress.increment_step()
# for the second step no_sync should not be called since we run optimizer step
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(state=state, data=dummy_iterator)
Expand Down Expand Up @@ -496,7 +496,7 @@ def _test_fsdp_no_sync() -> None:
auto_unit.train_step(state=state, data=dummy_iterator)
no_sync_mock.assert_called_once()

state.train_state.progress.increment_step()
auto_unit.train_progress.increment_step()
# for the second step no_sync should not be called since we run optimizer step
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(state=state, data=dummy_iterator)
Expand Down Expand Up @@ -1108,7 +1108,7 @@ def compute_loss(
tc = unittest.TestCase()
tc.assertEqual(
self._is_last_train_batch,
state.train_state.progress.num_steps_completed_in_epoch + 1
self.train_progress.num_steps_completed_in_epoch + 1
== self.expected_steps_per_epoch,
)
inputs, targets = data
Expand Down
30 changes: 14 additions & 16 deletions tests/framework/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def test_evaluate(self) -> None:
state = init_eval_state(dataloader=dataloader)
evaluate(state, my_unit)

self.assertEqual(state.eval_state.progress.num_epochs_completed, 1)
self.assertEqual(state.eval_state.progress.num_steps_completed_in_epoch, 0)
self.assertEqual(state.eval_state.progress.num_steps_completed, expected_steps)
self.assertEqual(my_unit.eval_progress.num_epochs_completed, 1)
self.assertEqual(my_unit.eval_progress.num_steps_completed_in_epoch, 0)
self.assertEqual(my_unit.eval_progress.num_steps_completed, expected_steps)
self.assertEqual(state.entry_point, EntryPoint.EVALUATE)

# step_output should be reset to None
Expand All @@ -62,11 +62,9 @@ def test_evaluate_max_steps_per_epoch(self) -> None:
)
evaluate(state, my_unit)

self.assertEqual(state.eval_state.progress.num_epochs_completed, 1)
self.assertEqual(state.eval_state.progress.num_steps_completed_in_epoch, 0)
self.assertEqual(
state.eval_state.progress.num_steps_completed, max_steps_per_epoch
)
self.assertEqual(my_unit.eval_progress.num_epochs_completed, 1)
self.assertEqual(my_unit.eval_progress.num_steps_completed_in_epoch, 0)
self.assertEqual(my_unit.eval_progress.num_steps_completed, max_steps_per_epoch)
self.assertEqual(state.entry_point, EntryPoint.EVALUATE)

# step_output should be reset to None
Expand All @@ -93,10 +91,10 @@ def test_evaluate_stop(self) -> None:
)
evaluate(state, my_unit)

self.assertEqual(state.eval_state.progress.num_epochs_completed, 1)
self.assertEqual(state.eval_state.progress.num_steps_completed_in_epoch, 0)
self.assertEqual(my_unit.eval_progress.num_epochs_completed, 1)
self.assertEqual(my_unit.eval_progress.num_steps_completed_in_epoch, 0)
self.assertEqual(
my_unit.steps_processed, state.eval_state.progress.num_steps_completed
my_unit.steps_processed, my_unit.eval_progress.num_steps_completed
)
self.assertEqual(my_unit.steps_processed, steps_before_stopping)

Expand Down Expand Up @@ -129,9 +127,9 @@ def eval_step(
state = init_eval_state(dataloader=dataloader)
evaluate(state, my_unit)

self.assertEqual(state.eval_state.progress.num_epochs_completed, 1)
self.assertEqual(state.eval_state.progress.num_steps_completed_in_epoch, 0)
self.assertEqual(state.eval_state.progress.num_steps_completed, expected_steps)
self.assertEqual(my_unit.eval_progress.num_epochs_completed, 1)
self.assertEqual(my_unit.eval_progress.num_steps_completed_in_epoch, 0)
self.assertEqual(my_unit.eval_progress.num_steps_completed, expected_steps)

# step_output should be reset to None
self.assertEqual(state.eval_state.step_output, None)
Expand All @@ -148,7 +146,7 @@ def test_evaluate_with_callback(self) -> None:
max_steps_per_epoch = 6
expected_num_steps = dataset_len / batch_size

my_unit = MagicMock()
my_unit = DummyEvalUnit(2)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_eval_state(
dataloader=dataloader, max_steps_per_epoch=max_steps_per_epoch
Expand Down Expand Up @@ -225,7 +223,7 @@ def eval_step(

assert state.eval_state
if (
state.eval_state.progress.num_steps_completed_in_epoch + 1
self.eval_progress.num_steps_completed_in_epoch + 1
== self.steps_before_stopping
):
state.stop()
Expand Down
Loading

0 comments on commit a80fc50

Please sign in to comment.