From f0d3729e515668a931fa48f77ac3da4dde2ed9cd Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Thu, 14 Sep 2023 15:01:26 -0700 Subject: [PATCH] progress util for steps in loop Differential Revision: D49241644 fbshipit-source-id: c431084634df8eaa3dbf8b04584ea55ee193004e --- tests/utils/test_progress.py | 83 +++++++++++++++++++++++++++++++++++- torchtnt/utils/progress.py | 35 +++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_progress.py b/tests/utils/test_progress.py index b24138f220..2aae976878 100644 --- a/tests/utils/test_progress.py +++ b/tests/utils/test_progress.py @@ -9,7 +9,11 @@ from torchtnt.framework._test_utils import generate_random_dataloader -from torchtnt.utils.progress import estimated_steps_in_epoch, Progress +from torchtnt.utils.progress import ( + estimated_steps_in_epoch, + estimated_steps_in_loop, + Progress, +) class ProgressTest(unittest.TestCase): @@ -82,3 +86,80 @@ def test_estimated_steps_in_epoch(self) -> None: ), dataloader_size, ) + + def test_estimated_steps_in_loop(self) -> None: + dataset_len = 10 + batch_size = 2 + dataloader = generate_random_dataloader( + num_samples=dataset_len, input_dim=2, batch_size=batch_size + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=20, + max_steps_per_epoch=6, + epochs=3, + ), + 15, # 5 steps per epoch because the dataset would be exhausted after that + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=20, + max_steps_per_epoch=4, + epochs=3, + ), + 12, # 4 steps per epoch, not exhausting all samples + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=8, + max_steps_per_epoch=6, + epochs=3, + ), + 8, # we finish in the 'middle' of an epoch because of max_steps + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=None, + max_steps_per_epoch=3, + epochs=3, + ), + 9, # when max_steps is none, we use epochs + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=None, + max_steps_per_epoch=None, + epochs=3, + ), + 15, # when max_steps is none, we use epochs + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=7, + max_steps_per_epoch=5, + epochs=None, + ), + 7, # when epoch is none, we use max_steps + ) + + self.assertEqual( + estimated_steps_in_loop( + dataloader, + max_steps=None, + max_steps_per_epoch=4, + epochs=None, + ), + None, + ) diff --git a/torchtnt/utils/progress.py b/torchtnt/utils/progress.py index 4f075d999c..528824623d 100644 --- a/torchtnt/utils/progress.py +++ b/torchtnt/utils/progress.py @@ -84,3 +84,38 @@ def estimated_steps_in_epoch( elif max_steps_per_epoch: total = min(total, max_steps_per_epoch) return total + + +def estimated_steps_in_loop( + dataloader: Iterable[object], + *, + max_steps: Optional[int], + max_steps_per_epoch: Optional[int], + epochs: Optional[int], +) -> Optional[int]: + """ + Estimate the total number of steps for the current loop. + + A return value of None indicates that the number of steps couldn't be estimated. + """ + + if not max_steps and not epochs: + return None + + if not epochs: + return max_steps + + total_steps = None + steps_per_epoch = estimated_steps_in_epoch( + dataloader, + num_steps_completed=0, + max_steps=max_steps, + max_steps_per_epoch=max_steps_per_epoch, + ) + if steps_per_epoch != float("inf"): + total_steps = int(steps_per_epoch) * epochs + + if total_steps and max_steps: + return min(total_steps, max_steps) + + return total_steps or max_steps