Skip to content

Commit

Permalink
Use CheckpointPath in get_x_checkpoint functions
Browse files (browse the repository at this point in the history)
Reviewed By: JKSenthil

Differential Revision: D56427223
  • Loading branch information
diego-urgell authored and facebook-github-bot committed May 1, 2024
1 parent e1135d6 commit e9182ec
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 122 deletions.
89 changes: 58 additions & 31 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def test_latest_checkpoint_path(self) -> None:
path_4 = os.path.join(temp_dir, "epoch_700")
os.mkdir(path_4)
self.assertEqual(
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2
get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
path_2,
)

@skip_if_not_distributed
Expand Down Expand Up @@ -284,7 +285,8 @@ def _latest_checkpoint_path_distributed() -> None:
expected_path = path_container[0]
tc.assertIsNotNone(expected_path)
tc.assertEqual(
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path
get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
expected_path,
)

if is_rank0:
Expand Down Expand Up @@ -368,7 +370,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
# compare as sets, since _retrieve_checkpoint_dirpaths does not guarantee
# the order of the returned dirpaths
self.assertEqual(
set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
{
str(x)
for x in _retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=None
)
},
{os.path.join(temp_dir, path) for path in paths[:-1]},
)
self.assertEqual(
Expand All @@ -382,9 +389,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
pass

self.assertEqual(
set(
_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata")
),
{
str(x)
for x in _retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=".metadata"
)
},
{os.path.join(temp_dir, paths[2])},
)

Expand All @@ -394,30 +404,36 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
"""
with tempfile.TemporaryDirectory() as temp_dir:
paths = [
"epoch_0_step_10_val_loss=10",
"epoch_1_step_10_val_loss=5",
"epoch_0_step_10_val_loss=10.0",
"epoch_1_step_10_val_loss=5.0",
"epoch_2_step_10",
"epoch_0_step_5",
"epoch_0_step_6_train_loss=13",
"epoch_0_step_6_train_loss=13.0",
]
for path in paths:
os.mkdir(os.path.join(temp_dir, path))
# make last path a file instead of a directory
with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"):
with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3.0"), "w"):
pass

# compare as sets, since _retrieve_checkpoint_dirpaths does not guarantee
# the order of the returned dirpaths
self.assertEqual(
set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
{
str(x)
for x in _retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=None
)
},
{os.path.join(temp_dir, path) for path in paths},
)
self.assertEqual(
set(
_retrieve_checkpoint_dirpaths(
{
str(x)
for x in _retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=None, metric_name="val_loss"
)
),
},
{
os.path.join(temp_dir, path) for path in paths[:2]
}, # since last path is a file
Expand All @@ -433,11 +449,12 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
pass

self.assertEqual(
set(
_retrieve_checkpoint_dirpaths(
{
str(x)
for x in _retrieve_checkpoint_dirpaths(
temp_dir, metadata_fname=".metadata", metric_name="val_loss"
)
),
},
{os.path.join(temp_dir, paths[1])},
)

Expand Down Expand Up @@ -467,7 +484,7 @@ def create_tmp_dir() -> str:
os.mkdir(path2)
torch.distributed.barrier()

ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir)
ckpt_dirpaths = [str(x) for x in get_checkpoint_dirpaths(temp_dir)]
tc = unittest.TestCase()
tc.assertEqual(set(ckpt_dirpaths), {path1, path2})

Expand All @@ -492,7 +509,7 @@ def test_get_checkpoint_dirpaths(self) -> None:
os.mkdir(path3)

self.assertEqual(
set(get_checkpoint_dirpaths(temp_dir)),
{str(x) for x in get_checkpoint_dirpaths(temp_dir)},
{path1, path2, path3},
)

Expand All @@ -505,7 +522,10 @@ def test_get_checkpoint_dirpaths(self) -> None:
os.mkdir(path3)

self.assertEqual(
set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")),
{
str(x)
for x in get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")
},
{path1, path2, path3},
)

Expand All @@ -519,20 +539,27 @@ def test_checkpoint_sorting_utils(self) -> None:
"""
Tests the sort utilities
"""
paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"]
self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]])
paths = [
"foo/epoch_1_step_20",
"foo/epoch_4_step_130",
"foo/epoch_0_step_10_val_loss=10.0",
]
ckpts = [CheckpointPath.from_str(x) for x in paths]
sorted_paths = [str(x) for x in _sort_by_recency(ckpts)]
self.assertEqual(sorted_paths, [paths[2], paths[0], paths[1]])

paths = [
"epoch_1_step_20_val_loss=0.09",
"epoch_4_step_130_val_loss=29",
"epoch_0_step_10_val_loss=10",
"foo/epoch_1_step_20_val_loss=0.09",
"foo/epoch_4_step_130_val_loss=29.0",
"foo/epoch_0_step_10_val_loss=10.0",
]
self.assertEqual(
_sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]]
)
self.assertEqual(
_sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]]
)
ckpts = [CheckpointPath.from_str(x) for x in paths]

sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="min")]
self.assertEqual(sorted_paths, [paths[1], paths[2], paths[0]])

sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="max")]
self.assertEqual(sorted_paths, [paths[0], paths[2], paths[1]])

def test_delete_checkpoint(self) -> None:
"""
Expand Down
7 changes: 5 additions & 2 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,14 @@ def __init__(

# sort by metric value if doing best checkpoint, else by recency
if best_checkpoint_config:
self._ckpt_dirpaths = _sort_by_metric_value(
ckpt_dirpaths = _sort_by_metric_value(
ckpt_dirpaths, mode=best_checkpoint_config.mode
)
else:
self._ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths)
ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths)

# TODO: Remove this conversion to str once CheckpointManager is used
self._ckpt_dirpaths = [str(x) for x in ckpt_dirpaths]

self._process_group: Optional[dist.ProcessGroup] = None
self._setup_gloo_pg(process_group)
Expand Down
Loading

0 comments on commit e9182ec

Please sign in to comment.