Add state_dict/load_state_dict support for MultiDataloader (#452)
Summary:
Pull Request resolved: #452

As title - MultiDataLoader now supports the stateful protocol for checkpoint saving/loading. The dataloader builds its state dict keyed by the names in `individual_dataloaders`, and includes an entry only for those underlying dataloaders that are themselves Stateful.
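
As a rough usage sketch (the import paths below and the `CountingLoader` helper are illustrative assumptions, not part of this diff; any dataloader exposing `state_dict`/`load_state_dict` participates):

from typing import Any, Dict, Iterator, List

from torchtnt.utils.data import MultiDataLoader  # assumed import path
from torchtnt.utils.data.iterators import InOrder  # assumed import path


class CountingLoader:
    # Hypothetical iterable that opts into the Stateful protocol by
    # implementing state_dict/load_state_dict.
    def __init__(self, vals: List[int]) -> None:
        self.vals = vals
        self.epochs_started = 0

    def __iter__(self) -> Iterator[int]:
        self.epochs_started += 1
        return iter(self.vals)

    def state_dict(self) -> Dict[str, Any]:
        return {"epochs_started": self.epochs_started}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epochs_started = state_dict["epochs_started"]


multi_dl = MultiDataLoader(
    {"train_a": CountingLoader([1, 2, 3]), "train_b": [4, 5, 6]},
    InOrder(),
)

# Only "train_a" appears in the aggregated state; the plain list under
# "train_b" does not implement the protocol and is skipped.
state = multi_dl.state_dict()
multi_dl.load_state_dict(state)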

Reviewed By: daniellepintz

Differential Revision: D47409954

fbshipit-source-id: d3700c5672ce950e403438cfb2b28f0c1ff11fbf
ananthsub authored and facebook-github-bot committed Jul 12, 2023
1 parent 5d7884f commit 1de9894
Showing 2 changed files with 107 additions and 2 deletions.
73 changes: 72 additions & 1 deletion tests/utils/data/test_multi_dataloader.py
@@ -8,7 +8,7 @@
import random
import unittest
from collections import Counter
from typing import Any, Dict, List, Mapping
from typing import Any, Dict, Iterator, List, Mapping

import torch
from torch.utils.data import DataLoader, Dataset
@@ -427,3 +427,74 @@ def test_in_order_with_repetitions(self) -> None:
# Raises StopIteration after all exhausted
with self.assertRaises(StopIteration):
batch = next(multi_dataloader)

def test_state_dict_load_state_dict(self) -> None:
class DummyIterable:
def __init__(self, vals: List[int]) -> None:
self.vals = vals
# Start at -1 since an iterator is generated when the MultiDataLoader is constructed
# while checking for missing data
self.iter_count = -1

def __iter__(self) -> Iterator[int]:
self.iter_count += 1
return iter(self.vals)

def state_dict(self) -> Dict[str, Any]:
return {"iter_count": self.iter_count}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.iter_count = state_dict["iter_count"]

iterable_1 = DummyIterable([1, 2, 3])
iterable_2 = DummyIterable([4, 5, 6])
# Add an iterable which does not implement the stateful protocol
iterable_3 = [7, 8, 9]

multi_dataloader = MultiDataLoader(
{"foo": iterable_1, "bar": iterable_2, "baz": iterable_3}, InOrder()
)

# Generate state dict from initial state
original_state_dict = multi_dataloader.state_dict()

# Confirm keys are appropriately set
self.assertIn("foo", original_state_dict)
self.assertIn("iter_count", original_state_dict["foo"])
self.assertEqual(0, original_state_dict["foo"]["iter_count"])
self.assertIn("bar", original_state_dict)
self.assertIn("iter_count", original_state_dict["bar"])
self.assertEqual(0, original_state_dict["bar"]["iter_count"])
self.assertNotIn("baz", original_state_dict)

for _ in multi_dataloader:
pass

self.assertEqual(multi_dataloader.individual_dataloaders["foo"].iter_count, 1)
self.assertEqual(multi_dataloader.individual_dataloaders["bar"].iter_count, 1)

new_state_dict = multi_dataloader.state_dict()

# Load state dict to reset to initial state
multi_dataloader.load_state_dict(original_state_dict)
self.assertEqual(multi_dataloader.individual_dataloaders["foo"].iter_count, 0)
self.assertEqual(multi_dataloader.individual_dataloaders["bar"].iter_count, 0)

# instantiate a new multi-dataloader where "bar" is replaced by a dataloader named "qux"
new_multi_dataloader = MultiDataLoader(
{
"foo": DummyIterable([1, 2, 3]),
"qux": DummyIterable([4, 5, 6]),
"baz": [7, 8, 9],
},
InOrder(),
)
new_multi_dataloader.load_state_dict(new_state_dict)
# foo's iter_count should be loaded correctly
self.assertEqual(
new_multi_dataloader.individual_dataloaders["foo"].iter_count, 1
)
# qux's iter_count should still be 0 because there is no corresponding entry in the loaded state dict
self.assertEqual(
new_multi_dataloader.individual_dataloaders["qux"].iter_count, 0
)
36 changes: 35 additions & 1 deletion torchtnt/utils/data/multi_dataloader.py
@@ -15,6 +15,7 @@
DataIterationStrategyRegistry,
MultiIterator,
)
from torchtnt.utils.stateful import Stateful

if TYPE_CHECKING:
from torch.utils.data import DataLoader
@@ -64,7 +65,7 @@ def __init__(
)

def __iter__(self) -> Iterator[Dict[str, Any]]:
"""Iterator functions for the collection of dataloaders
"""Iterator functions for the collection of dataloaders.
Returns:
a newly created iterator based on DataIterationStrategy
@@ -80,3 +81,36 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
iteration_strategy=self.iteration_strategy,
)
return self.iterator

def state_dict(self) -> Dict[str, Any]:
"""Return an aggregated state dict based on individual dataloaders.
The state dict is keyed off the names provided by ``individual_dataloaders``.
Note:
Only states from dataloaders that implement the :class:`~torchtnt.utils.stateful.Stateful` protocol are included in the returned state dict.
"""
state_dict = {}
for name, dl in self.individual_dataloaders.items():
if isinstance(dl, Stateful):
state_dict[name] = dl.state_dict()

return state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Loads aggregated state dict based on individual dataloaders.
The provided state dict should be keyed off the names provided by ``individual_dataloaders``.
Note:
Only states from dataloaders that implement the :class:`~torchtnt.utils.stateful.Stateful` protocol are loaded.
"""
for name, dl in self.individual_dataloaders.items():
if isinstance(dl, Stateful):
contents = state_dict.get(name, None)
if contents is None:
logger.warning(
f"Skipping loading state dict for dataloader {name} as there is no corresponding entry in the state dict"
)
continue
dl.load_state_dict(contents)
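
For completeness, a minimal sketch of wiring the aggregated state into a checkpoint file, assuming the same import paths as in the sketch above, standard `torch.save`/`torch.load`, and an illustrative checkpoint path. Plain `torch.utils.data.DataLoader` objects do not implement the Stateful protocol, so they contribute no entries here; stateful dataloaders (like `DummyIterable` in the test above) would.

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchtnt.utils.data import MultiDataLoader  # assumed import path
from torchtnt.utils.data.iterators import InOrder  # assumed import path

multi_dl = MultiDataLoader(
    {
        "images": DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2),
        "text": DataLoader(TensorDataset(torch.randn(8, 5)), batch_size=2),
    },
    InOrder(),
)

# Persist the dataloader state next to the rest of the checkpoint.
torch.save({"dataloaders": multi_dl.state_dict()}, "/tmp/example_ckpt.pt")

# On resume: dataloaders without a corresponding entry in the loaded state
# dict are skipped with a warning instead of raising; non-Stateful
# dataloaders are ignored entirely.
ckpt = torch.load("/tmp/example_ckpt.pt")
multi_dl.load_state_dict(ckpt["dataloaders"])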
