From 1064d10348019204567dbda011c4972d87fb5b2b Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Tue, 1 Oct 2024 21:14:30 -0700 Subject: [PATCH] Enable creation of predict checkpoint paths (#907) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/907 Reviewed By: anshulverma, JKSenthil Differential Revision: D63013010 fbshipit-source-id: f7e872ebd4b65d74f312a0dd25d220c4931af658 --- tests/utils/test_checkpoint.py | 60 ++++++++++++++++++++++++++++++++++ torchtnt/utils/checkpoint.py | 35 +++++++++++++++----- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index 690cc4d0a9..046e009236 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -66,6 +66,32 @@ def test_create_checkpoint_path(self) -> None: ) self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_foo=1.0") + # evaluation only + ckpt = CheckpointPath( + "foo", + epoch=0, + step={Phase.EVALUATE: 1}, + ) + self.assertEqual(ckpt.path, "foo/epoch_0_eval_step_1") + + # prediction only + ckpt = CheckpointPath( + "foo", + epoch=0, + step={Phase.PREDICT: 1}, + ) + self.assertEqual(ckpt.path, "foo/epoch_0_predict_step_1") + + # all phases - not expected but should work + ckpt = CheckpointPath( + "foo", + epoch=0, + step={Phase.TRAIN: 1, Phase.EVALUATE: 1, Phase.PREDICT: 1}, + ) + self.assertEqual( + ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_predict_step_1" + ) + # nan metric value with self.assertRaisesRegex( ValueError, @@ -90,6 +116,8 @@ def test_from_str(self) -> None: "foo/epoch_2.6_step_23", "foo/epoch_3_pred_step_3", "foo/epoch_3__step_3", + "foo/epoch_2_predict_step_2_eval_step_1", + "foo/epoch_2_predict_step_3.2", ] for path in malformed_paths: with self.assertRaisesRegex( @@ -110,6 +138,15 @@ def test_from_str(self) -> None: "foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0) ), ), + ( + "foo/epoch_14_step_3_train_loss=15.0", + CheckpointPath( + "foo", + epoch=14, + step={Phase.NONE: 3}, + metric_data=MetricData("train_loss", 15.0), + ), + ), ( "foo/epoch_14_step_3_loss=-27.35", CheckpointPath( @@ -122,6 +159,23 @@ def test_from_str(self) -> None: "/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35) ), ), + ( + "foo/epoch_2_eval_step_23", + CheckpointPath("foo", epoch=2, step={Phase.EVALUATE: 23}), + ), + ( + "foo/epoch_14_predict_step_5", + CheckpointPath("foo", epoch=14, step={Phase.PREDICT: 5}), + ), + ( + "foo/epoch_14_train_step_3_eval_loss=0.1", + CheckpointPath( + "foo", + epoch=14, + step={Phase.TRAIN: 3}, + metric_data=MetricData("eval_loss", 0.1), + ), + ), ( "foo/bar/epoch_23_step_31_mean_loss_squared=0.0", CheckpointPath( @@ -266,6 +320,12 @@ def test_compare_by_recency(self) -> None: self.assertTrue(eval_only < multiphase_2) self.assertTrue(multiphase_2 < multiphase_3) + predict_1 = CheckpointPath("foo", epoch=3, step={Phase.PREDICT: 10}) + predict_2 = CheckpointPath("foo", epoch=4, step={Phase.PREDICT: 10}) + predict_3 = CheckpointPath("foo", epoch=4, step={Phase.PREDICT: 20}) + self.assertTrue(predict_1 < predict_2) + self.assertTrue(predict_2 < predict_3) + def test_compare_by_optimality(self) -> None: # not both metric aware ckpt1 = CheckpointPath("foo", epoch=0, step=1) diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 5129fc9e9f..bc352d132c 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -53,6 +53,7 @@ class Phase(Enum): NONE = 0 # Only used for backwards compatibility TRAIN = 1 EVALUATE = 2 + PREDICT = 3 @total_ordering @@ -81,7 +82,7 @@ class CheckpointPath: ) PHASE_AWARE_REGEX: Pattern = re.compile( - r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$" + r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$" ) def __init__( @@ -142,8 +143,9 @@ def _populate_from_str(self, checkpoint_path: str) -> None: Raises: ValueError: If the path is malformed (either non-parsable, or contains wrong data types) """ - is_phase_aware = ( - "train_step" in checkpoint_path or "eval_step" in checkpoint_path + is_phase_aware = any( + phase in checkpoint_path + for phase in ["train_step", "eval_step", "predict_step"] ) regex = self.PHASE_AWARE_REGEX if is_phase_aware else self.PHASE_NAIVE_REGEX path_match = regex.match(checkpoint_path) @@ -155,13 +157,22 @@ def _populate_from_str(self, checkpoint_path: str) -> None: try: step_mapping: Dict[Phase, int] = {} if is_phase_aware: - dirpath, epoch, train_steps, eval_steps, metric_name, metric_value = ( - path_match.groups() - ) + ( + dirpath, + epoch, + train_steps, + eval_steps, + predict_steps, + metric_name, + metric_value, + ) = path_match.groups() + if train_steps is not None: step_mapping[Phase.TRAIN] = int(train_steps) if eval_steps is not None: step_mapping[Phase.EVALUATE] = int(eval_steps) + if predict_steps is not None: + step_mapping[Phase.PREDICT] = int(predict_steps) else: dirpath, epoch, naive_steps, metric_name, metric_value = ( @@ -200,6 +211,8 @@ def path(self) -> str: name += f"_train_step_{self.step[Phase.TRAIN]}" if Phase.EVALUATE in self.step: name += f"_eval_step_{self.step[Phase.EVALUATE]}" + if Phase.PREDICT in self.step: + name += f"_predict_step_{self.step[Phase.PREDICT]}" if self.metric_data: name += f"_{self.metric_data.name}={self.metric_data.value}" @@ -240,9 +253,13 @@ def newer_than(self, other: "CheckpointPath") -> bool: # Otherwise, compare first by eval and then train steps return self._get_phase_steps() > other._get_phase_steps() - def _get_phase_steps(self) -> Tuple[int, int]: - """Tuple with the phase steps ordered by phase priority in comparison (first eval, then train).""" - return self.step.get(Phase.EVALUATE, 0), self.step.get(Phase.TRAIN, 0) + def _get_phase_steps(self) -> Tuple[int, ...]: + """Tuple with the phase steps ordered by phase priority in comparison (predict, eval, train).""" + return ( + self.step.get(Phase.PREDICT, 0), + self.step.get(Phase.EVALUATE, 0), + self.step.get(Phase.TRAIN, 0), + ) def more_optimal_than( self, other: "CheckpointPath", mode: Literal["min", "max"]