Skip to content

Commit

Permalink
Enable creation of predict checkpoint paths (#907)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #907

Reviewed By: anshulverma, JKSenthil

Differential Revision: D63013010
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Oct 2, 2024
1 parent d86828b commit bee9db3
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 9 deletions.
60 changes: 60 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,32 @@ def test_create_checkpoint_path(self) -> None:
)
self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_foo=1.0")

# evaluation only
ckpt = CheckpointPath(
"foo",
epoch=0,
step={Phase.EVALUATE: 1},
)
self.assertEqual(ckpt.path, "foo/epoch_0_eval_step_1")

# prediction only
ckpt = CheckpointPath(
"foo",
epoch=0,
step={Phase.PREDICT: 1},
)
self.assertEqual(ckpt.path, "foo/epoch_0_predict_step_1")

# all phases - not expected but should work
ckpt = CheckpointPath(
"foo",
epoch=0,
step={Phase.TRAIN: 1, Phase.EVALUATE: 1, Phase.PREDICT: 1},
)
self.assertEqual(
ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_predict_step_1"
)

# nan metric value
with self.assertRaisesRegex(
ValueError,
Expand All @@ -90,6 +116,8 @@ def test_from_str(self) -> None:
"foo/epoch_2.6_step_23",
"foo/epoch_3_pred_step_3",
"foo/epoch_3__step_3",
"foo/epoch_2_predict_step_2_eval_step_1",
"foo/epoch_2_predict_step_3.2",
]
for path in malformed_paths:
with self.assertRaisesRegex(
Expand All @@ -110,6 +138,15 @@ def test_from_str(self) -> None:
"foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0)
),
),
(
"foo/epoch_14_step_3_train_loss=15.0",
CheckpointPath(
"foo",
epoch=14,
step={Phase.NONE: 3},
metric_data=MetricData("train_loss", 15.0),
),
),
(
"foo/epoch_14_step_3_loss=-27.35",
CheckpointPath(
Expand All @@ -122,6 +159,23 @@ def test_from_str(self) -> None:
"/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
),
),
(
"foo/epoch_2_eval_step_23",
CheckpointPath("foo", epoch=2, step={Phase.EVALUATE: 23}),
),
(
"foo/epoch_14_predict_step_5",
CheckpointPath("foo", epoch=14, step={Phase.PREDICT: 5}),
),
(
"foo/epoch_14_train_step_3_eval_loss=0.1",
CheckpointPath(
"foo",
epoch=14,
step={Phase.TRAIN: 3},
metric_data=MetricData("eval_loss", 0.1),
),
),
(
"foo/bar/epoch_23_step_31_mean_loss_squared=0.0",
CheckpointPath(
Expand Down Expand Up @@ -266,6 +320,12 @@ def test_compare_by_recency(self) -> None:
self.assertTrue(eval_only < multiphase_2)
self.assertTrue(multiphase_2 < multiphase_3)

predict_1 = CheckpointPath("foo", epoch=3, step={Phase.PREDICT: 10})
predict_2 = CheckpointPath("foo", epoch=4, step={Phase.PREDICT: 10})
predict_3 = CheckpointPath("foo", epoch=4, step={Phase.PREDICT: 20})
self.assertTrue(predict_1 < predict_2)
self.assertTrue(predict_2 < predict_3)

def test_compare_by_optimality(self) -> None:
# not both metric aware
ckpt1 = CheckpointPath("foo", epoch=0, step=1)
Expand Down
35 changes: 26 additions & 9 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Phase(Enum):
NONE = 0 # Only used for backwards compatibility
TRAIN = 1
EVALUATE = 2
PREDICT = 3


@total_ordering
Expand Down Expand Up @@ -81,7 +82,7 @@ class CheckpointPath:
)

PHASE_AWARE_REGEX: Pattern = re.compile(
r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
)

def __init__(
Expand Down Expand Up @@ -142,8 +143,9 @@ def _populate_from_str(self, checkpoint_path: str) -> None:
Raises:
ValueError: If the path is malformed (either non-parsable, or contains wrong data types)
"""
is_phase_aware = (
"train_step" in checkpoint_path or "eval_step" in checkpoint_path
is_phase_aware = any(
phase in checkpoint_path
for phase in ["train_step", "eval_step", "predict_step"]
)
regex = self.PHASE_AWARE_REGEX if is_phase_aware else self.PHASE_NAIVE_REGEX
path_match = regex.match(checkpoint_path)
Expand All @@ -155,13 +157,22 @@ def _populate_from_str(self, checkpoint_path: str) -> None:
try:
step_mapping: Dict[Phase, int] = {}
if is_phase_aware:
dirpath, epoch, train_steps, eval_steps, metric_name, metric_value = (
path_match.groups()
)
(
dirpath,
epoch,
train_steps,
eval_steps,
predict_steps,
metric_name,
metric_value,
) = path_match.groups()

if train_steps is not None:
step_mapping[Phase.TRAIN] = int(train_steps)
if eval_steps is not None:
step_mapping[Phase.EVALUATE] = int(eval_steps)
if predict_steps is not None:
step_mapping[Phase.PREDICT] = int(predict_steps)

else:
dirpath, epoch, naive_steps, metric_name, metric_value = (
Expand Down Expand Up @@ -200,6 +211,8 @@ def path(self) -> str:
name += f"_train_step_{self.step[Phase.TRAIN]}"
if Phase.EVALUATE in self.step:
name += f"_eval_step_{self.step[Phase.EVALUATE]}"
if Phase.PREDICT in self.step:
name += f"_predict_step_{self.step[Phase.PREDICT]}"

if self.metric_data:
name += f"_{self.metric_data.name}={self.metric_data.value}"
Expand Down Expand Up @@ -240,9 +253,13 @@ def newer_than(self, other: "CheckpointPath") -> bool:
# Otherwise, compare first by predict, then eval, and then train steps
return self._get_phase_steps() > other._get_phase_steps()

def _get_phase_steps(self) -> Tuple[int, int]:
"""Tuple with the phase steps ordered by phase priority in comparison (first eval, then train)."""
return self.step.get(Phase.EVALUATE, 0), self.step.get(Phase.TRAIN, 0)
def _get_phase_steps(self) -> Tuple[int, ...]:
"""Tuple with the phase steps ordered by phase priority in comparison (predict, eval, train)."""
return (
self.step.get(Phase.PREDICT, 0),
self.step.get(Phase.EVALUATE, 0),
self.step.get(Phase.TRAIN, 0),
)

def more_optimal_than(
self, other: "CheckpointPath", mode: Literal["min", "max"]
Expand Down

0 comments on commit bee9db3

Please sign in to comment.