Skip to content

Commit

Permalink
progress util for steps in loop
Browse files Browse the repository at this point in the history
Differential Revision: D49241644

fbshipit-source-id: c431084634df8eaa3dbf8b04584ea55ee193004e
  • Loading branch information
galrotem authored and facebook-github-bot committed Sep 14, 2023
1 parent a274b65 commit f0d3729
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 1 deletion.
83 changes: 82 additions & 1 deletion tests/utils/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from torchtnt.framework._test_utils import generate_random_dataloader

from torchtnt.utils.progress import estimated_steps_in_epoch, Progress
from torchtnt.utils.progress import (
estimated_steps_in_epoch,
estimated_steps_in_loop,
Progress,
)


class ProgressTest(unittest.TestCase):
Expand Down Expand Up @@ -82,3 +86,80 @@ def test_estimated_steps_in_epoch(self) -> None:
),
dataloader_size,
)

def test_estimated_steps_in_loop(self) -> None:
    """Check the loop-wide step estimate for several combinations of limits."""
    num_samples = 10
    samples_per_batch = 2
    # 10 samples in batches of 2: the dataloader is exhausted after 5 steps.
    dataloader = generate_random_dataloader(
        num_samples=num_samples, input_dim=2, batch_size=samples_per_batch
    )

    # Each row: (max_steps, max_steps_per_epoch, epochs, expected estimate).
    scenarios = (
        # 5 steps per epoch, because the dataset is exhausted before the cap of 6
        (20, 6, 3, 15),
        # 4 steps per epoch, not exhausting all samples
        (20, 4, 3, 12),
        # max_steps stops the loop in the 'middle' of an epoch
        (8, 6, 3, 8),
        # no max_steps: per-epoch cap times the number of epochs
        (None, 3, 3, 9),
        # no caps at all: full dataloader length times the number of epochs
        (None, None, 3, 15),
        # no epochs: max_steps is the estimate
        (7, 5, None, 7),
        # neither max_steps nor epochs: nothing to estimate from
        (None, 4, None, None),
    )

    for max_steps, max_steps_per_epoch, epochs, expected in scenarios:
        self.assertEqual(
            estimated_steps_in_loop(
                dataloader,
                max_steps=max_steps,
                max_steps_per_epoch=max_steps_per_epoch,
                epochs=epochs,
            ),
            expected,
        )
35 changes: 35 additions & 0 deletions torchtnt/utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,38 @@ def estimated_steps_in_epoch(
elif max_steps_per_epoch:
total = min(total, max_steps_per_epoch)
return total


def estimated_steps_in_loop(
    dataloader: Iterable[object],
    *,
    max_steps: Optional[int],
    max_steps_per_epoch: Optional[int],
    epochs: Optional[int],
) -> Optional[int]:
    """
    Estimate the total number of steps for the current loop.

    The estimate is the per-epoch estimate returned by
    ``estimated_steps_in_epoch`` multiplied by ``epochs``, capped at
    ``max_steps`` when a cap is set.

    Args:
        dataloader: iterable consumed by the loop; forwarded to
            ``estimated_steps_in_epoch``.
        max_steps: optional cap on the total number of steps. A falsy value
            (``None`` or ``0``) is treated as "no cap".
        max_steps_per_epoch: optional cap on the steps of a single epoch;
            forwarded to ``estimated_steps_in_epoch``.
        epochs: optional number of epochs. A falsy value (``None`` or ``0``)
            is treated as "not set".

    Returns:
        The estimated number of steps. A return value of None indicates that
        the number of steps couldn't be estimated.
    """

    if not max_steps and not epochs:
        # Neither bound is set, so there is nothing to estimate from.
        return None

    if not epochs:
        # Without an epoch count, max_steps is the only available bound.
        return max_steps

    steps_per_epoch = estimated_steps_in_epoch(
        dataloader,
        num_steps_completed=0,
        max_steps=max_steps,
        max_steps_per_epoch=max_steps_per_epoch,
    )
    if steps_per_epoch == float("inf"):
        # The per-epoch estimate is unbounded; fall back to max_steps,
        # which may itself be unset.
        return max_steps

    total_steps = int(steps_per_epoch) * epochs
    # A total of 0 is a valid estimate (the dataloader yields no steps) and
    # must not be confused with "unknown", so only max_steps is checked for
    # truthiness here.
    if max_steps:
        return min(total_steps, max_steps)
    return total_steps

0 comments on commit f0d3729

Please sign in to comment.