Skip to content

Commit

Permalink
Extract does_checkpoint_exist into util (#906)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #906

Reviewed By: anshulverma, schwarzmx

Differential Revision: D63036738

fbshipit-source-id: 3aa053e4c788e8f42a0f03a2fa2995510836172a
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Oct 4, 2024
1 parent 6d99aae commit 1f06115
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
22 changes: 22 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BestCheckpointConfig,
CheckpointManager,
CheckpointPath,
does_checkpoint_exist,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
get_latest_checkpoint_path,
Expand Down Expand Up @@ -1419,6 +1420,27 @@ def test_does_checkpoint_metadata_exist(self) -> None:
)
)

def test_does_checkpoint_exist(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
ckpt_1 = os.path.join(temp_dir, "checkpoint_1")
os.mkdir(ckpt_1)

self.assertFalse(does_checkpoint_exist(ckpt_1, metadata_fname=None))

with open(os.path.join(ckpt_1, ".metadata"), "w"):
pass

self.assertFalse(does_checkpoint_exist(ckpt_1, metadata_fname="manifest"))
self.assertTrue(does_checkpoint_exist(ckpt_1, metadata_fname=".metadata"))
self.assertTrue(
does_checkpoint_exist(ckpt_1, metadata_fname=["manifest", ".metadata"])
)
self.assertFalse(
does_checkpoint_exist(
ckpt_1, metadata_fname=["manifest", ".state_dict_info"]
)
)


class MyValLossUnit(TrainUnit[Batch]):
def __init__(self) -> None:
Expand Down
40 changes: 32 additions & 8 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,21 +539,18 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
# No metric tracked, most recents goes last
self._ckpt_paths.append(ckpt)

@rank_zero_read_and_broadcast
def does_checkpoint_exist(
self, ckpt: CheckpointPath, process_group: Optional[dist.ProcessGroup] = None
self,
ckpt: CheckpointPath,
process_group: Optional[dist.ProcessGroup] = None,
) -> bool:
"""
Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but
result is broadcasted to all ranks.
"""
if not self._metadata_fnames:
return False

fs, _ = url_to_fs(self.dirpath)
return any(
_metadata_exists(fs, ckpt.path, fname) for fname in self._metadata_fnames
return does_checkpoint_exist(
ckpt.path, self._metadata_fnames, process_group=process_group
)

@staticmethod
Expand Down Expand Up @@ -596,6 +593,33 @@ def remove_checkpoint(self) -> None:
)


@rank_zero_read_and_broadcast
def does_checkpoint_exist(
ckpt_path: str,
metadata_fname: Union[str, List[str]],
process_group: Optional[dist.ProcessGroup] = None,
) -> bool:
"""
Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
Will return False if the metadata_fname is None. Check is executed in rank 0, but
result is broadcasted to all ranks.
Args:
ckpt: The checkpoint to check.
metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
"""
if not metadata_fname:
return False
else:
metadata_fnames = (
[metadata_fname] if isinstance(metadata_fname, str) else metadata_fname
)

fs, _ = url_to_fs(ckpt_path)
return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames)


@rank_zero_read_and_broadcast
def get_latest_checkpoint_path(
dirpath: str,
Expand Down

0 comments on commit 1f06115

Please sign in to comment.