Skip to content

Commit

Permalink
progress util for steps in fit (pytorch#538)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#538

Create a util for getting the number of estimated steps in fit

Reviewed By: JKSenthil

Differential Revision: D49246239

fbshipit-source-id: 765ebcc65d2138022172fa49176ec2325f39599f
  • Loading branch information
galrotem authored and facebook-github-bot committed Sep 14, 2023
1 parent f0d3729 commit 3c99e52
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 0 deletions.
104 changes: 104 additions & 0 deletions tests/utils/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
# LICENSE file in the root directory of this source tree.

import unittest
from unittest.mock import patch

from torchtnt.framework._test_utils import generate_random_dataloader

from torchtnt.utils.progress import (
estimated_steps_in_epoch,
estimated_steps_in_fit,
estimated_steps_in_loop,
Progress,
)
Expand Down Expand Up @@ -163,3 +165,105 @@ def test_estimated_steps_in_loop(self) -> None:
),
None,
)

def test_estimated_steps_in_fit(self) -> None:
    """
    Verify the fit-level step estimate with ``estimated_steps_in_loop`` mocked out.

    The patched helper returns canned values in call order: the first call made by
    ``estimated_steps_in_fit`` is for the training loop, the second (if any) is for
    a single eval epoch.
    """
    dataloader = generate_random_dataloader(
        num_samples=1,
        input_dim=1,
        batch_size=1,
    )
    loop_estimator = "torchtnt.utils.progress.estimated_steps_in_loop"

    # Each row: (epochs, eval_every_n_steps, eval_every_n_epochs, expected total)
    # with 100 mocked training steps and 20 mocked steps per eval epoch.
    scenarios = [
        # 100 (training) + 20 * 12 (eval epochs: 100/10 + 4/2)
        (4, 10, 2, 340),
        # 100 (training) + 20 (single eval epoch)
        (3, None, 2, 120),
        # 100 (training) + 20 * 2 (two eval epochs)
        (3, 49, None, 140),
        # just training
        (3, None, None, 100),
    ]

    # One [train, eval] pair of mocked return values is queued per scenario.
    with patch(loop_estimator, side_effect=[100, 20] * len(scenarios)):
        for num_epochs, every_n_steps, every_n_epochs, expected_total in scenarios:
            self.assertEqual(
                estimated_steps_in_fit(
                    train_dataloader=dataloader,
                    eval_dataloader=dataloader,
                    epochs=num_epochs,
                    max_steps=None,
                    max_train_steps_per_epoch=None,
                    max_eval_steps_per_epoch=None,
                    eval_every_n_steps=every_n_steps,
                    eval_every_n_epochs=every_n_epochs,
                ),
                expected_total,
            )

    # If either the eval estimate (first case) or the training estimate
    # (second case) comes back as None, the overall estimate is None.
    for mocked_estimates in ([100, None], [None, 20]):
        with patch(loop_estimator, side_effect=mocked_estimates):
            self.assertEqual(
                estimated_steps_in_fit(
                    train_dataloader=dataloader,
                    eval_dataloader=dataloader,
                    epochs=4,
                    max_steps=None,
                    max_train_steps_per_epoch=None,
                    max_eval_steps_per_epoch=None,
                    eval_every_n_steps=10,
                    eval_every_n_epochs=2,
                ),
                None,
            )
47 changes: 47 additions & 0 deletions torchtnt/utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,50 @@ def estimated_steps_in_loop(
return min(total_steps, max_steps)

return total_steps or max_steps


def estimated_steps_in_fit(
    *,
    train_dataloader: Iterable[object],
    eval_dataloader: Iterable[object],
    epochs: Optional[int],
    max_steps: Optional[int],
    max_train_steps_per_epoch: Optional[int],
    max_eval_steps_per_epoch: Optional[int],
    eval_every_n_steps: Optional[int],
    eval_every_n_epochs: Optional[int],
) -> Optional[int]:
    """
    Estimate the total number of steps (training plus evaluation) for a fit run.

    Args:
        train_dataloader: iterable used for the training loop estimate.
        eval_dataloader: iterable used for the per-eval-epoch estimate.
        epochs: number of training epochs, if bounded.
        max_steps: cap on the total number of training steps, if any.
        max_train_steps_per_epoch: cap on training steps per epoch, if any.
        max_eval_steps_per_epoch: cap on eval steps per eval epoch, if any.
        eval_every_n_steps: run an eval epoch every this many training steps.
        eval_every_n_epochs: run an eval epoch every this many training epochs.

    Returns:
        The estimated step count, or None when the training estimate or the
        (required) eval estimate is falsy, i.e. could not be calculated.
    """
    train_total = estimated_steps_in_loop(
        train_dataloader,
        max_steps=max_steps,
        max_steps_per_epoch=max_train_steps_per_epoch,
        epochs=epochs,
    )
    if not train_total:
        return None

    # No evaluation schedule configured: the fit consists of training only.
    if not (eval_every_n_steps or eval_every_n_epochs):
        return train_total

    steps_per_eval_run = estimated_steps_in_loop(
        eval_dataloader,
        max_steps=None,
        max_steps_per_epoch=max_eval_steps_per_epoch,
        epochs=1,
    )
    if not steps_per_eval_run:
        return None

    # Eval epochs triggered by the epoch schedule and by the step schedule
    # are counted independently and summed.
    epoch_triggered_evals = (
        epochs // eval_every_n_epochs if eval_every_n_epochs and epochs else 0
    )
    step_triggered_evals = (
        train_total // eval_every_n_steps if eval_every_n_steps else 0
    )

    eval_runs = epoch_triggered_evals + step_triggered_evals
    return train_total + eval_runs * steps_per_eval_run

0 comments on commit 3c99e52

Please sign in to comment.