
Rescalability via IBM dataset layers #1372

Draft: wants to merge 22 commits into base: main

Conversation

daviswer

Implements rescaling of checkpoints to different world sizes and numbers of workers. The user specifies the number of data partitions in advance; when checkpoints are saved and loaded with a different total worker count (which must divide the partition count evenly), stateful guarantees are maintained: data already seen is not revisited until the next epoch.

Based on the datasets in the corresponding IBM torchtitan PR, but uses StatefulDataLoader and DCP to manage checkpointing from the master process. Sampling and Dummy datasets are included for demo purposes. It may be possible to merge the IBM datasets into the existing node structure.
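
For context, a minimal sketch of the rescaling invariant described above (all names and numbers are hypothetical, not from the PR): the logical shard count is fixed up front and must divide evenly by the total worker count both at save time and at load time.

# Illustrative only; names and numbers are hypothetical.
n_logical_shards = 24        # chosen in advance by the user
workers_at_save = 8          # world size x workers when the checkpoint was written
workers_at_load = 6          # world size x workers when the checkpoint is restored

assert n_logical_shards % workers_at_save == 0
assert n_logical_shards % workers_at_load == 0

# The same 24 shard states are simply regrouped on reload:
# 3 shards per worker before, 4 shards per worker after.
print(n_logical_shards // workers_at_save, n_logical_shards // workers_at_load)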

Changes

  • Add IBM rescalable datasets and checkpointing functions to torchdata/stateful_dataloader/ibm_rescalable.py
  • Add demo script and correctness check to examples/ibm_rescaling/rescaling_demo.py

@facebook-github-bot added the CLA Signed label on Nov 23, 2024

def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:

Contributor: Are tail elements just truncated?
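
For comparison, a remainder-preserving partition might look like the following sketch (not the PR's implementation): it hands one extra item to the first len(itemlist) % worldsize ranks instead of truncating the tail.

from typing import Any, List

def shard_partition_no_truncation(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
    # Sketch only: spread the remainder across the first ranks so no tail
    # elements are dropped.
    base, extra = divmod(len(itemlist), worldsize)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return itemlist[start:end]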

# Setup / loading flags
self.is_setup = False

def setup(self):
Contributor: This can be mapped pretty easily to BaseNode.reset()
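
A rough sketch of the mapping being suggested, using stand-in classes (BaseNode's real signature is not reproduced here): the one-time work guarded by is_setup moves into a reset() hook that also accepts a restored state.

# Stand-in classes, purely illustrative; not torchdata code.
class _BaseNodeStub:
    def reset(self, initial_state=None):
        pass

class SetupViaResetSketch(_BaseNodeStub):
    def __init__(self):
        self.is_setup = False

    def reset(self, initial_state=None):
        super().reset(initial_state)
        if not self.is_setup:
            self.is_setup = True
            # one-time work previously done in setup(), e.g. resolving worker info
        if initial_state is not None:
            # state restoration previously done in load_state_dict()
            self.restored_state = initial_state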

[setattr(self, flag, state_dict[self.statename(flag)]) for flag in self.state_params]


class _WrapperDataset(_StatefulDataset):
Contributor: thinking out loud: could we do this with mixins instead of extending the type hierarchy?

Contributor: Actually, what's the benefit of having two subclasses?
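
One possible mixin shape the thread is pointing at, with stub classes standing in for the PR's types (illustrative only): the wrapper concern becomes a mixin layered over a single stateful base instead of a second subclass level.

class _StatefulDatasetStub:
    # stands in for _StatefulDataset
    def setup(self):
        self.is_setup = True

class _DelegatingSetupMixin:
    # stands in for what _WrapperDataset adds: fan setup() out to sub-datasets
    def setup(self):
        super().setup()
        for child in getattr(self, "data", []):
            child.setup()

class ScalableShardSketch(_DelegatingSetupMixin, _StatefulDatasetStub):
    def __init__(self, children):
        self.data = children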

while True:
    ind = self.current_reader
    # Read doc
    out = next(data[ind])
Contributor: How is StopIteration handled?
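
One way the round-robin read could guard against an exhausted sub-iterator, shown as a self-contained sketch (a guess at the intent, not the PR's actual handling):

from typing import Any, Iterable, Iterator, List

def next_with_restart(iters: List[Iterator[Any]], sources: List[Iterable[Any]], ind: int) -> Any:
    # Sketch: if the current logical shard is exhausted, re-open it so the
    # round-robin read rolls over into the next epoch instead of raising.
    try:
        return next(iters[ind])
    except StopIteration:
        iters[ind] = iter(sources[ind])
        return next(iters[ind])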

Comment on lines +483 to +490
# Convert to tensor form
out = {}
for k, v in state_dict.items():
    v = torch.tensor(v)
    if len(v.shape) == 0:
        k = k + ".scalar"
        v = v.unsqueeze(0)
    out[k] = v
Contributor: Is this done to satisfy DCP requirements?
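
If the tensor-form conversion above is indeed there to give DCP a dict of tensors, the load path presumably needs the inverse; a sketch of such an inverse (assumed, not taken from the PR):

import torch

def from_tensor_form(tensor_dict):
    # Sketch: undo the ".scalar" tagging and unsqueeze from the quoted snippet.
    out = {}
    for k, v in tensor_dict.items():
        if k.endswith(".scalar"):
            out[k[: -len(".scalar")]] = v.item()
        else:
            out[k] = v.tolist()
    return out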

#### ------------------------- CHECKPOINT FUNCTIONS ------------------------- ####


def __pop_dstate(state, device_mesh, placements):
Contributor: We should create standard utilities to get these in torchdata #1337
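
A sketch of the kind of helper #1337 could provide (hypothetical; not an existing torchdata API, and the import paths vary slightly across PyTorch releases): a 1-D mesh over all ranks with the dataloader state sharded along dim 0.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard  # torch.distributed._tensor on older releases

def default_mesh_and_placements(device_type: str = "cpu"):
    # Requires an initialized process group.
    world_size = dist.get_world_size()
    device_mesh = init_device_mesh(device_type, (world_size,))
    placements = [Shard(0)]
    return device_mesh, placements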

self.current_reader = (self.current_reader + 1) % self.n_logicals
yield out

def state_dict(self):
Contributor:

{
  my_children: [c.state_dict() for c in self.children],
  scalar_state: self.scalar, # "my_string"
  my_reshardable_state: tensor.array([1, 2, 3, 4, 5]), # 2d tensor
}

Question: what happens if the above state_dict gets passed to DCP?
Answer: it will fail because torch.tensor gets called on everything?

Andrew to follow up with @pradeepfn on this.

assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???"
# Flip list[dict[Any]] to dict[list[Any]]
state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()}
state_dict.update(_StatefulDataset.state_dict(self))
@andrewkho (Contributor, Dec 18, 2024): Does self.current_reader need to be stored too? i.e. for determinism in the case where resharding doesn't happen at all.

self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8))


class ScalableShardDataset(_WrapperDataset):
Contributor: Protocol based instead.

Suggested change:
- class ScalableShardDataset(_WrapperDataset):
+ class ScalableShardDataset(BaseNode[T], Reshardable):
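
A sketch of what such a Reshardable protocol might look like (hypothetical; neither the name nor the methods exist in torchdata):

from typing import Any, Dict, List, Protocol, runtime_checkable

@runtime_checkable
class Reshardable(Protocol):
    # Anything that can emit its state as per-logical-shard dicts and
    # re-absorb a different grouping of those shards.
    def shard_state_dicts(self) -> List[Dict[str, Any]]:
        ...

    def load_shard_state_dicts(self, shards: List[Dict[str, Any]]) -> None:
        ...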

Comment on lines +465 to +469
data = [iter(d) for d in self.data]
while True:
    ind = self.current_reader
    # Read doc
    out = next(data[ind])
Contributor: Can we bound the number of open iterators/filepointers/etc in some way here while still maintaining re-shardability

Contributor: Possibly run into end-of-epoch problem

Contributor: can we remove the assumption of indexable files
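
To make the trade-offs in this thread concrete, a rough sketch of a bounded-open-readers round-robin (illustrative only). It caps the number of live iterators but resumes an evicted shard by slicing, i.e. it leans on exactly the indexability the last comment asks to remove, and it has to handle the end-of-epoch case explicitly.

from typing import Any, Dict, Iterator, List, Sequence

def bounded_round_robin(shards: List[Sequence[Any]], max_open: int = 2) -> Iterator[Any]:
    # Sketch only: per-shard positions are tracked separately from the open
    # iterators, so an evicted shard can be re-opened at its saved position.
    positions = [0] * len(shards)
    open_iters: Dict[int, Iterator[Any]] = {}
    ind = 0
    while True:
        if ind not in open_iters:
            if len(open_iters) >= max_open:
                open_iters.pop(next(iter(open_iters)))  # evict the oldest open reader
            open_iters[ind] = iter(shards[ind][positions[ind]:])  # "seek" by slicing
        try:
            item = next(open_iters[ind])
        except StopIteration:
            # end of epoch for this shard: rewind it and move on
            positions[ind] = 0
            open_iters.pop(ind)
            ind = (ind + 1) % len(shards)
            continue
        positions[ind] += 1
        yield item
        ind = (ind + 1) % len(shards)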

logical_shard_states = [d.state_dict() for d in self.data]
assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???"
# Flip list[dict[Any]] to dict[list[Any]]
state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()}

qq: is len(logical_shard_states) always 1? Looking at the list comprehension, it seems so, but I do not understand why. Thanks.

Update: I think I got it. The keys are the same across sub-datasets, so logical_shard_states[0].keys() is used as the anchor?
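
A tiny standalone demonstration of the flip being discussed (values made up); every shard emits the same keys, which is why shard 0's keys can anchor the comprehension:

logical_shard_states = [
    {"epoch": 0, "docs_seen": 10},
    {"epoch": 0, "docs_seen": 7},
]
flipped = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()}
assert flipped == {"epoch": [0, 0], "docs_seen": [10, 7]}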

writer,
)
# Write nondistributed state dict
torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth"))

Shall we also make this state part of the main checkpoint? We can use torch.save serialization (the output is a bytestream) to store the state as part of the DCP checkpoint:

buff = io.BytesIO()
torch.save(state, buff)  # or we can serialize individual keys in the state dict, but there's no strong need
buff.seek(0)

Assume the unique key is 'trainer_dataloader_state_rank_k'; update dstate with the new key -> value, then checkpoint.save(dstate).
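
A fleshed-out sketch of that suggestion (the key name and helper are hypothetical, and it assumes DCP's default planner will store non-tensor values as serialized bytes): serialize the per-rank state into a buffer and save it inside the same DCP checkpoint rather than as a side file.

import io

import torch
import torch.distributed.checkpoint as dcp

def save_with_embedded_rank_state(dstate: dict, state: dict, rank: int, path: str) -> None:
    # Sketch only: pack the non-distributed, per-rank state as bytes under a
    # rank-specific key, then write everything in one DCP checkpoint.
    buff = io.BytesIO()
    torch.save(state, buff)
    buff.seek(0)
    dstate[f"trainer_dataloader_state_rank_{rank}"] = buff.read()
    dcp.save(dstate, checkpoint_id=path)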

ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x])
# Check that number of loaders matches
if ckp_ws == loader.dataset.worldsize:
    state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth"))

noob q: what are we missing out on if we just set

data_loader_state = base

without considering the rescaling property of the training run?
