[small] format composability.md (#517)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #473
* __->__ #517
* #516

Ran `pre-commit run --all-files`
H-Huang authored Aug 12, 2024
1 parent a4bc948 commit a47a5a9
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions docs/composability.md
@@ -10,7 +10,7 @@ When applying Pipeline Parallelism, you will have to construct nn.Module objects
Most likely, you can write your model such that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top-level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.
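
A minimal sketch of that structure (`Block` and `Model` are illustrative names, not torchtitan's actual classes):

```python
import torch.nn as nn

class Block(nn.Module):
    # Stand-in for a transformer block; the real child module is more involved.
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ff(x)

class Model(nn.Module):
    def __init__(self, n_layers: int, dim: int, vocab: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList([Block(dim) for _ in range(n_layers)])
        self.output = nn.Linear(dim, vocab)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        # The top-level forward is essentially a for-loop over child modules,
        # so pipeline partitioning reduces to choosing which layers each stage keeps.
        for layer in self.layers:
            h = layer(h)
        return self.output(h)
```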

example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
we used to slice the `freqs_cis` buffer by `seq_len` in the top-level forward, pass that slice into child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But at PP-splitting time we don't know whether TP will be applied, which could create a mismatch. It's just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.
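
A sketch of the child-module-side slicing; the attention class and the way the frequencies are applied are placeholders, not the PR's exact code:

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        # Placeholder precompute; the real buffer holds rotary embedding frequencies.
        self.register_buffer("freqs_cis", torch.randn(max_seq_len, dim))

    def forward(self, x):                  # x: (batch, local_seq_len, dim)
        seq_len = x.shape[1]               # runtime-accurate local length, TP or not
        freqs = self.freqs_cis[:seq_len]   # slice here instead of in the top-level forward
        return self.wq(x) * freqs          # placeholder for the real RoPE application
```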

example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)): We decided to reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward still does the right thing. This means we don't have to write a separate runtime pp_forward that glues together child modules per stage. The first change was using a `ModuleDict` instead of a `ModuleList` to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers: e.g. `layers.1` stays `layers.1` even if you remove `layers.0`, which isn't true for a list, and this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP), since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional: if the layer exists, we run it; otherwise we feed the input through to bypass it. With these two changes, we can (meta)-initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.
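
A sketch of both changes, with illustrative names and a hypothetical stage split:

```python
import torch.nn as nn

class PipelineFriendlyModel(nn.Module):
    def __init__(self, n_layers: int, dim: int, vocab: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        # ModuleDict keeps FQNs stable ("layers.1" stays "layers.1") after deletions.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)

    def forward(self, x):
        if self.tok_embeddings is not None:   # only the first stage owns the embedding
            x = self.tok_embeddings(x)
        for layer in self.layers.values():
            x = layer(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.output is not None:           # only the last stage owns the output head
            x = self.output(x)
        return x

model = PipelineFriendlyModel(n_layers=4, dim=16, vocab=100)
# A hypothetical non-first stage: drop the embedding and the first two layers.
model.tok_embeddings = None
del model.layers["0"], model.layers["1"]
# FQNs of the surviving layers are unchanged, which is what DCP needs.
print(sorted({n.rsplit(".", 1)[0] for n, _ in model.named_parameters()}))
# -> ['layers.2', 'layers.3', 'norm', 'output']
```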

@@ -20,4 +20,3 @@ Initializing the pipeline-parallel model is challenging because we assume the mo
For now, we sidestep all these problems with a simple but brutal solution: Initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. For future work, we consider adding a more elaborate initialization scheme to `torch.pipelining`.
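
A sketch of that flow, reusing the `PipelineFriendlyModel` sketch above; the checkpoint path and the stage split are illustrative, and torchtitan's checkpoint manager wraps these calls differently:

```python
import torch
import torch.distributed.checkpoint as dcp

# Assumes torch.distributed is already initialized (e.g. via torchrun) and a
# seed checkpoint was saved beforehand with dcp.save.
with torch.device("meta"):
    model = PipelineFriendlyModel(n_layers=4, dim=16, vocab=100)  # no memory allocated

# Prune to this stage, e.g. keep only layers 2 and 3 (illustrative split).
model.tok_embeddings = None
del model.layers["0"], model.layers["1"]

model.to_empty(device="cuda")        # materialize the surviving modules, uninitialized

state_dict = model.state_dict()
dcp.load(state_dict, checkpoint_id="outputs/seed_checkpoint")  # loads only the FQNs we kept
model.load_state_dict(state_dict)
```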

One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers, or else we have to specially initialize those in `train.py` after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
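
For reference, a minimal illustration of the persistent/non-persistent distinction (the precompute here is a placeholder):

```python
import torch
import torch.nn as nn

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        freqs = torch.randn(max_seq_len, dim)            # placeholder precompute
        # persistent=True (the default) puts the buffer in state_dict(), so a seed
        # checkpoint can restore it; persistent=False would leave it out and require
        # special re-initialization after pipeline splitting.
        self.register_buffer("freqs_cis", freqs, persistent=True)

m = Rotary(dim=16, max_seq_len=128)
assert "freqs_cis" in m.state_dict()
```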
