diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index 046e009236..1cd068abc2 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -29,6 +29,7 @@ BestCheckpointConfig, CheckpointManager, CheckpointPath, + does_checkpoint_exist, get_best_checkpoint_path, get_checkpoint_dirpaths, get_latest_checkpoint_path, @@ -1419,6 +1420,27 @@ def test_does_checkpoint_metadata_exist(self) -> None: ) ) + def test_does_checkpoint_exist(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + ckpt_1 = os.path.join(temp_dir, "checkpoint_1") + os.mkdir(ckpt_1) + + self.assertFalse(does_checkpoint_exist(ckpt_1, metadata_fname=None)) + + with open(os.path.join(ckpt_1, ".metadata"), "w"): + pass + + self.assertFalse(does_checkpoint_exist(ckpt_1, metadata_fname="manifest")) + self.assertTrue(does_checkpoint_exist(ckpt_1, metadata_fname=".metadata")) + self.assertTrue( + does_checkpoint_exist(ckpt_1, metadata_fname=["manifest", ".metadata"]) + ) + self.assertFalse( + does_checkpoint_exist( + ckpt_1, metadata_fname=["manifest", ".state_dict_info"] + ) + ) + class MyValLossUnit(TrainUnit[Batch]): def __init__(self) -> None: diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index bc352d132c..02420d973a 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -539,21 +539,18 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None: # No metric tracked, most recents goes last self._ckpt_paths.append(ckpt) - @rank_zero_read_and_broadcast def does_checkpoint_exist( - self, ckpt: CheckpointPath, process_group: Optional[dist.ProcessGroup] = None + self, + ckpt: CheckpointPath, + process_group: Optional[dist.ProcessGroup] = None, ) -> bool: """ Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory. If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but result is broadcasted to all ranks. """ - if not self._metadata_fnames: - return False - - fs, _ = url_to_fs(self.dirpath) - return any( - _metadata_exists(fs, ckpt.path, fname) for fname in self._metadata_fnames + return does_checkpoint_exist( + ckpt.path, self._metadata_fnames, process_group=process_group ) @staticmethod @@ -596,6 +593,33 @@ def remove_checkpoint(self) -> None: ) +@rank_zero_read_and_broadcast +def does_checkpoint_exist( + ckpt_path: str, + metadata_fname: Union[str, List[str]], + process_group: Optional[dist.ProcessGroup] = None, +) -> bool: + """ + Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory. + Will return False if the metadata_fname is None. Check is executed in rank 0, but + result is broadcasted to all ranks. + + Args: + ckpt: The checkpoint to check. + metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present. + process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used. + """ + if not metadata_fname: + return False + else: + metadata_fnames = ( + [metadata_fname] if isinstance(metadata_fname, str) else metadata_fname + ) + + fs, _ = url_to_fs(ckpt_path) + return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames) + + @rank_zero_read_and_broadcast def get_latest_checkpoint_path( dirpath: str,