Update on "3d with fp8 in test runner"
fp8 not working in CI


[ghstack-poisoned]
H-Huang committed Sep 30, 2024
2 parents dfa588c + c8159b0 commit af18855
Showing 19 changed files with 161 additions and 80 deletions.
34 changes: 15 additions & 19 deletions README.md
@@ -3,7 +3,7 @@

# torchtitan

`torchtitan` is currently in a pre-release state and under extensive development.
`torchtitan` is currently in a pre-release state and under extensive development. Currently we showcase pre-training **Llama 3.1**, **Llama 3**, and **Llama 2** LLMs of various sizes from scratch. To use the latest features of `torchtitan`, we recommend using the latest PyTorch nightly.

`torchtitan` is a proof-of-concept for Large-scale LLM training using native PyTorch. It is (and will continue to be) a repo to showcase PyTorch's latest distributed training features in a clean, minimal codebase. torchtitan is complementary to and not a replacement for any of the great large-scale LLM training codebases such as Megatron, Megablocks, LLM Foundry, Deepspeed, etc. Instead, we hope that the features showcased in torchtitan will be adopted by these codebases quickly. torchtitan is unlikely to ever grow a large community around it.

@@ -26,34 +26,30 @@ You may want to see how the model is defined or how parallelism techniques are a
* [torchtitan/parallelisms/pipeline_llama.py](torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](torchtitan/float8.py) - utils for applying Float8 techniques
* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)

## Pre-Release Updates:
#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly).
* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama 2 and Llama 3 variants)

### Key features available

1. [FSDP2 with per param sharding](docs/fsdp.md)
2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including async TP)
1. [FSDP2](docs/fsdp.md) with per param sharding
2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including [async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487))
3. Selective layer and operator activation checkpointing
4. Distributed checkpointing (including async checkpointing)
5. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via TensorBoard
7. Learning rate scheduler, meta-init, optional Fused RMSNorm
8. [Float8 support](docs/float8.md)
6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via [TensorBoard](#tensorboard)
7. Learning rate scheduler, meta-init, optional Fused RMSNorm
8. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
9. `torch.compile` support
10. All options easily configured via [toml files](train_configs/)
11. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
10. DDP and HSDP
11. All options easily configured via [toml files](train_configs/)
12. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning

We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
We report our [Performance](docs/performance.md) verified on 64/128 GPUs.


### Coming soon

1. Context Parallel
2. Pipeline Parallel (and 3D parallelism)
3. HSDP
- Pipeline Parallel (and 3D parallelism)
- Context Parallel


## Installation
Expand All @@ -74,10 +70,10 @@ Once you have confirmed access, you can run the following command to download th
```bash
# Get your HF token from https://huggingface.co/settings/tokens

# llama3 or 3.1 tokenizer.model
# Llama 3 or 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

# llama2 tokenizer.model
# Llama 2 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
```

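A note on the tokenizer step above: the bundled `download_tokenizer.py` script is the supported path, but the same files can be fetched with the `huggingface_hub` API directly. The sketch below is illustrative only and assumes the gated repos expose `tokenizer.model` at the paths shown (the `original/` subfolder for Llama 3/3.1, the repo root for Llama 2); it is not part of this commit.

```python
# Sketch: fetch tokenizer.model directly via huggingface_hub.
# Assumes access to the gated Meta Llama repos and a token from
# https://huggingface.co/settings/tokens.
from huggingface_hub import hf_hub_download

hf_token = "hf_..."  # placeholder; substitute your own token

# Llama 3 / 3.1: tokenizer.model is assumed to live in the "original" subfolder.
llama3_tokenizer = hf_hub_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    filename="original/tokenizer.model",
    token=hf_token,
)

# Llama 2: tokenizer.model is assumed to sit at the repo root.
llama2_tokenizer = hf_hub_download(
    repo_id="meta-llama/Llama-2-13b-hf",
    filename="tokenizer.model",
    token=hf_token,
)

print(llama3_tokenizer, llama2_tokenizer)
```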
4 changes: 2 additions & 2 deletions estimation.py
@@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig):
job_config.experimental.enable_compiled_autograd = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=job_config.training.enable_loss_parallel,
dp_type=job_config.training.data_parallel_type,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
2 changes: 1 addition & 1 deletion test/test_fused_rms_norm.py
@@ -11,7 +11,7 @@
Replicate,
Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
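The hunk above switches `CommDebugMode` from the private `torch.distributed._tensor.debug` module to the public `torch.distributed.tensor.debug` namespace. For code that has to run on both older and newer PyTorch builds, a hedged compatibility pattern (not part of this commit) is:

```python
# Sketch: prefer the public namespace, fall back to the private path
# used by older PyTorch releases.
try:
    from torch.distributed.tensor.debug import CommDebugMode  # newer builds
except ImportError:
    from torch.distributed._tensor.debug import CommDebugMode  # older builds
```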
40 changes: 32 additions & 8 deletions test_runner.py
@@ -157,7 +157,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--training.data_parallel_shard_degree 1",
],
],
"PP 1D test 1f1b",
@@ -172,7 +172,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--training.data_parallel_shard_degree 1",
],
],
"PP 1D test gpipe",
@@ -187,7 +187,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
],
],
"PP+DP 1f1b 2D test",
@@ -201,7 +201,7 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
],
],
"PP+DP gpipe 2D test",
@@ -227,15 +227,15 @@ def build_test_list():
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
[
"--training.steps 20",
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
],
@@ -249,7 +249,7 @@ def build_test_list():
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
"--float8.enable_float8_linear",
"--float8.enable_fsdp_float8_all_gather",
@@ -302,13 +302,37 @@ def build_test_list():
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
"--training.data_parallel_shard_degree=1",
"--training.data_parallel_replicate_degree=4",
]
],
"DDP",
"ddp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
]
],
"HSDP",
"hsdp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--training.tensor_parallel_degree=2",
]
],
"HSDP+TP",
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
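The new DDP, HSDP, and HSDP+TP test entries above select the parallelism mode purely through the replicate/shard degrees. As a sanity check on the chosen `ngpu` values, the product of the declared degrees must equal the GPU count (the rule `ParallelDims._validate` enforces); a small sketch of that arithmetic:

```python
# Sketch: each new test's degrees must multiply out to its ngpu
# (pipeline parallel degree is 1 for these tests).
configs = {
    "ddp":     {"dp_replicate": 4, "dp_shard": 1, "tp": 1, "ngpu": 4},
    "hsdp":    {"dp_replicate": 2, "dp_shard": 2, "tp": 1, "ngpu": 4},
    "hsdp+tp": {"dp_replicate": 2, "dp_shard": 2, "tp": 2, "ngpu": 8},
}

for name, c in configs.items():
    product = c["dp_replicate"] * c["dp_shard"] * c["tp"]
    assert product == c["ngpu"], f"{name}: {product} != {c['ngpu']}"
    print(f"{name}: {c['dp_replicate']} x {c['dp_shard']} x {c['tp']} = {product} GPUs")
```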
4 changes: 2 additions & 2 deletions torchtitan/checkpoint.py
@@ -84,7 +84,7 @@ class ModelWrapper(Stateful):
def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model

def state_dict(self) -> None:
def state_dict(self) -> Dict[str, Any]:
return {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
}
@@ -107,7 +107,7 @@ def __init__(
self.model = [model] if isinstance(model, nn.Module) else model
self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim

def state_dict(self) -> None:
def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
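Both hunks above correct the `state_dict` return annotation from `None` to `Dict[str, Any]`, which is what distributed checkpointing expects from a `Stateful` object. A minimal, hedged sketch of that protocol (the class name here is illustrative, not from the commit):

```python
# Sketch: the Stateful shape consumed by torch.distributed.checkpoint --
# state_dict() returns a mapping, load_state_dict() restores from one.
from typing import Any, Dict

from torch.distributed.checkpoint.stateful import Stateful


class TrainStepState(Stateful):
    """Toy example; the real wrappers hold models and optimizers."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"]
```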
34 changes: 26 additions & 8 deletions torchtitan/config_manager.py
@@ -224,10 +224,34 @@ def __init__(self):
help="How many train steps to run",
)
self.parser.add_argument(
"--training.data_parallel_degree",
"--training.data_parallel_replicate_degree",
type=int,
default=1,
help="""
The `data_parallel_replicate_degree` argument specifies the degree of
data parallelism for weight replication. When this value is greater
than 1, weights will be replicated across `data_parallel_replicate_degree`
ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
parallelism method used is DDP (Distributed Data Parallelism).
1 means disabled.""",
)
self.parser.add_argument(
"--training.data_parallel_shard_degree",
type=int,
default=-1,
help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
help="""
The `data_parallel_shard_degree` argument specifies the degree of data
parallelism for weight sharding. When this value is greater than 1, weights
will be sharded across `data_parallel_shard_degree` ranks. If
`data_parallel_replicate_degree` is also greater than 1, the parallelism
method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
parallelism method used is FSDP (Fully Sharded Data Parallelism).
-1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree`
can be negative.
1 means disabled.""",
)
self.parser.add_argument(
"--training.tensor_parallel_degree",
@@ -297,12 +321,6 @@ def __init__(self):
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--training.data_parallel_type",
type=str,
default="fsdp",
help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
)
self.parser.add_argument(
"--experimental.enable_compiled_autograd",
action="store_true",
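The help strings added above spell out how the two data-parallel flags combine, and the old `--training.data_parallel_type` switch is removed because the mode is now implied by the degrees. A hedged, standalone sketch of that selection rule (it mirrors the described behavior rather than importing torchtitan):

```python
# Sketch: which data-parallel flavor falls out of the two degrees.
# A shard degree of -1 means "use the leftover ranks" after the other
# parallelisms are accounted for, per the help text above.
def resolve_dp_mode(dp_replicate: int, dp_shard: int, tp: int, pp: int, world_size: int) -> str:
    if dp_shard == -1:
        dp_shard = world_size // (dp_replicate * tp * pp)
    if dp_replicate > 1 and dp_shard > 1:
        return "HSDP"
    if dp_replicate > 1:
        return "DDP"
    if dp_shard > 1:
        return "FSDP"
    return "no data parallelism"


print(resolve_dp_mode(dp_replicate=1, dp_shard=-1, tp=2, pp=1, world_size=8))  # FSDP (shard over 4 ranks)
print(resolve_dp_mode(dp_replicate=4, dp_shard=1, tp=1, pp=1, world_size=4))   # DDP
print(resolve_dp_mode(dp_replicate=2, dp_shard=2, tp=2, pp=1, world_size=8))   # HSDP
```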
3 changes: 1 addition & 2 deletions torchtitan/float8.py
@@ -49,8 +49,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_enabled
and parallel_dims.dp_type == "fsdp"
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
scaling_type_input = ScalingType(float8_config.scaling_type_input)
1 change: 0 additions & 1 deletion torchtitan/models/llama/model.py
@@ -29,7 +29,6 @@ class ModelArgs:
norm_eps: float = 1e-5
rope_theta: float = 10000

max_batch_size: int = 32
max_seq_len: int = 2048
# If `True`, then each transformer block init uses its layer ID, and if
# `False`, each uses the total number of transformer blocks
63 changes: 48 additions & 15 deletions torchtitan/parallelisms/parallel_dims.py
@@ -13,45 +13,78 @@

@dataclass
class ParallelDims:
dp: int
dp_replicate: int
dp_shard: int
tp: int
pp: int
world_size: int
enable_loss_parallel: bool
dp_type: str

def __post_init__(self):
self.dp_type = self.dp_type.lower()
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
assert dp >= 1, dp
dp_replicate, dp_shard, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert self.dp_type in ("fsdp", "ddp")
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
strict=True,
):
if d > 1:
dims.append(d)
names.append(name)
if (name == "dp_replicate" and self.dp_shard == 1) or (
name == "dp_shard" and self.dp_replicate == 1
):
names.append("dp")
else:
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
# Create all the submesh here to ensure all required process groups are
# initialized
if self.dp_replicate > 1 and self.dp_shard > 1:
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
return mesh

@property
def dp_enabled(self):
return self.dp > 1
return self.dp_replicate > 1 or self.dp_shard > 1

@property
def dp_replicate_enabled(self):
return self.dp_replicate > 1

@property
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def tp_enabled(self):
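The rewritten `_validate` and `build_mesh` above derive mesh dimension names from the configured degrees: a lone replicate or shard dimension is simply named `dp`, while genuine HSDP keeps separate `dp_replicate` and `dp_shard` dimensions and flattens them into a `dp` submesh afterwards. Below is a hedged, standalone mirror of just the naming rule; no process groups are created here, since `init_device_mesh` needs an initialized distributed environment.

```python
# Sketch: reproduce the mesh-dimension naming applied by build_mesh(),
# without touching torch.distributed. Mirrors the commit's logic; not an API.
from typing import List, Tuple


def mesh_names(pp: int, dp_replicate: int, dp_shard: int, tp: int) -> Tuple[List[int], List[str]]:
    dims, names = [], []
    for d, name in zip(
        [pp, dp_replicate, dp_shard, tp], ["pp", "dp_replicate", "dp_shard", "tp"]
    ):
        if d > 1:
            dims.append(d)
            # A lone DP dimension (pure DDP or pure FSDP) is just called "dp".
            if (name == "dp_replicate" and dp_shard == 1) or (
                name == "dp_shard" and dp_replicate == 1
            ):
                names.append("dp")
            else:
                names.append(name)
    return dims, names


print(mesh_names(pp=1, dp_replicate=1, dp_shard=4, tp=2))  # ([4, 2], ['dp', 'tp'])
print(mesh_names(pp=1, dp_replicate=2, dp_shard=2, tp=2))  # ([2, 2, 2], ['dp_replicate', 'dp_shard', 'tp'])
```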
