diff --git a/tests/utils/data/test_multi_dataloader.py b/tests/utils/data/test_multi_dataloader.py
index c9f434128b..a3ad13da95 100644
--- a/tests/utils/data/test_multi_dataloader.py
+++ b/tests/utils/data/test_multi_dataloader.py
@@ -8,7 +8,7 @@
 import random
 import unittest
 from collections import Counter
-from typing import Any, Dict, List, Mapping
+from typing import Any, Dict, Iterator, List, Mapping
 
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -427,3 +427,74 @@ def test_in_order_with_repetitions(self) -> None:
         # Raises StopIteration after all exhausted
         with self.assertRaises(StopIteration):
             batch = next(multi_dataloader)
+
+    def test_state_dict_load_state_dict(self) -> None:
+        class DummyIterable:
+            def __init__(self, vals: List[int]) -> None:
+                self.vals = vals
+                # Start at -1 since an iterator is generated when the MultiDataLoader is constructed
+                # while checking for missing data
+                self.iter_count = -1
+
+            def __iter__(self) -> Iterator[int]:
+                self.iter_count += 1
+                return iter(self.vals)
+
+            def state_dict(self) -> Dict[str, Any]:
+                return {"iter_count": self.iter_count}
+
+            def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+                self.iter_count = state_dict["iter_count"]
+
+        iterable_1 = DummyIterable([1, 2, 3])
+        iterable_2 = DummyIterable([4, 5, 6])
+        # Add an iterable that does not implement the Stateful protocol
+        iterable_3 = [7, 8, 9]
+
+        multi_dataloader = MultiDataLoader(
+            {"foo": iterable_1, "bar": iterable_2, "baz": iterable_3}, InOrder()
+        )
+
+        # Generate state dict from initial state
+        original_state_dict = multi_dataloader.state_dict()
+
+        # Confirm keys are appropriately set
+        self.assertIn("foo", original_state_dict)
+        self.assertIn("iter_count", original_state_dict["foo"])
+        self.assertEqual(0, original_state_dict["foo"]["iter_count"])
+        self.assertIn("bar", original_state_dict)
+        self.assertIn("iter_count", original_state_dict["bar"])
+        self.assertEqual(0, original_state_dict["bar"]["iter_count"])
+        self.assertNotIn("baz", original_state_dict)
+
+        for _ in multi_dataloader:
+            pass
+
+        self.assertEqual(multi_dataloader.individual_dataloaders["foo"].iter_count, 1)
+        self.assertEqual(multi_dataloader.individual_dataloaders["bar"].iter_count, 1)
+
+        new_state_dict = multi_dataloader.state_dict()
+
+        # Load state dict to reset to initial state
+        multi_dataloader.load_state_dict(original_state_dict)
+        self.assertEqual(multi_dataloader.individual_dataloaders["foo"].iter_count, 0)
+        self.assertEqual(multi_dataloader.individual_dataloaders["bar"].iter_count, 0)
+
+        # Instantiate a new multi-dataloader with a different dataloader name
+        new_multi_dataloader = MultiDataLoader(
+            {
+                "foo": DummyIterable([1, 2, 3]),
+                "qux": DummyIterable([4, 5, 6]),
+                "baz": [7, 8, 9],
+            },
+            InOrder(),
+        )
+        new_multi_dataloader.load_state_dict(new_state_dict)
+        # foo's iter_count should be loaded correctly
+        self.assertEqual(
+            new_multi_dataloader.individual_dataloaders["foo"].iter_count, 1
+        )
+        # qux's iter_count should still be 0 because it was not in the original state dict
+        self.assertEqual(
+            new_multi_dataloader.individual_dataloaders["qux"].iter_count, 0
+        )
diff --git a/torchtnt/utils/data/multi_dataloader.py b/torchtnt/utils/data/multi_dataloader.py
index ae183a0ce9..f92ce8b3ca 100644
--- a/torchtnt/utils/data/multi_dataloader.py
+++ b/torchtnt/utils/data/multi_dataloader.py
@@ -15,6 +15,7 @@
     DataIterationStrategyRegistry,
     MultiIterator,
 )
+from torchtnt.utils.stateful import Stateful
 
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
@@ -64,7 +65,7 @@ def __init__(
         )
 
     def __iter__(self) -> Iterator[Dict[str, Any]]:
-        """Iterator functions for the collection of dataloaders
+        """Iterator functions for the collection of dataloaders.
 
         Returns: a newly created iterator based on DataIterationStrategy
@@ -80,3 +81,36 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
             iteration_strategy=self.iteration_strategy,
         )
         return self.iterator
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Returns an aggregated state dict based on the individual dataloaders.
+
+        The state dict is keyed off the names provided by ``individual_dataloaders``.
+
+        Note:
+            Only states from dataloaders that implement the :class:`~torchtnt.utils.stateful.Stateful` protocol are included in the returned state dict.
+        """
+        state_dict = {}
+        for name, dl in self.individual_dataloaders.items():
+            if isinstance(dl, Stateful):
+                state_dict[name] = dl.state_dict()
+
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Loads an aggregated state dict into the individual dataloaders.
+
+        The provided state dict should be keyed off the names provided by ``individual_dataloaders``.
+
+        Note:
+            Only states from dataloaders that implement the :class:`~torchtnt.utils.stateful.Stateful` protocol are loaded.
+        """
+        for name, dl in self.individual_dataloaders.items():
+            if isinstance(dl, Stateful):
+                contents = state_dict.get(name, None)
+                if contents is None:
+                    logger.warning(
+                        f"Skipping loading state dict for dataloader {name} as there is no corresponding entry in the state dict"
+                    )
+                    continue
+                dl.load_state_dict(contents)
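
Usage sketch (illustrative only, not part of the diff): the new methods aggregate per-dataloader state keyed by name, which pairs naturally with checkpoint save/restore. The import paths and the StatefulRange helper below are assumptions for illustration, not code from this PR; any iterable exposing state_dict/load_state_dict is picked up, everything else is skipped.

# Illustrative sketch only -- import paths are assumed, adjust to your install.
from typing import Any, Dict, Iterator

from torchtnt.utils.data.iterators import InOrder
from torchtnt.utils.data.multi_dataloader import MultiDataLoader


class StatefulRange:
    """Hypothetical iterable that tracks how many times it has been iterated."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.epochs = 0

    def __iter__(self) -> Iterator[int]:
        self.epochs += 1
        return iter(range(self.n))

    def state_dict(self) -> Dict[str, Any]:
        return {"epochs": self.epochs}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epochs = state_dict["epochs"]


multi_dl = MultiDataLoader(
    {"train": StatefulRange(10), "extra": [1, 2, 3]},  # "extra" is not Stateful
    InOrder(),
)
for _ in multi_dl:
    pass

# Only "train" appears in the snapshot; "extra" has no state_dict and is omitted.
snapshot = multi_dl.state_dict()

# Restore into a freshly built MultiDataLoader: names present in the snapshot are
# loaded, while Stateful dataloaders missing from it are skipped with a logged warning.
fresh = MultiDataLoader({"train": StatefulRange(10), "extra": [1, 2, 3]}, InOrder())
fresh.load_state_dict(snapshot)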