From c44cca02d016c3c362b7efef0a95ac5485623a16 Mon Sep 17 00:00:00 2001 From: Hugo <6937752+fduwjj@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:51:07 -0700 Subject: [PATCH 1/2] [EZ][405B] Use scientific notation for 405B model lr (#504) As title, use `8e-5` rather than `0.8e-4`. --- train_configs/llama3_405b.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml index fb250642..b7f78dc2 100644 --- a/train_configs/llama3_405b.toml +++ b/train_configs/llama3_405b.toml @@ -23,7 +23,7 @@ tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model" [optimizer] name = "AdamW" -lr = 0.8e-4 +lr = 8e-5 [training] batch_size = 2 From 8849580d4d6f57a8cb3416aed4ed6f6ada3e5ab5 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Sat, 3 Aug 2024 21:54:36 -0700 Subject: [PATCH 2/2] [BE][4/n] split pipeline_llama into a separate file ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/499 --- README.md | 6 +- estimation.py | 8 +- torchtitan/{float8_linear.py => float8.py} | 10 +- torchtitan/metrics.py | 6 +- torchtitan/parallelisms/__init__.py | 67 +-- torchtitan/parallelisms/parallel_dims.py | 70 +++ torchtitan/parallelisms/parallelize_llama.py | 537 ++++++------------- torchtitan/parallelisms/pipeline_llama.py | 221 ++++++++ train.py | 11 +- 9 files changed, 474 insertions(+), 462 deletions(-) rename torchtitan/{float8_linear.py => float8.py} (95%) create mode 100644 torchtitan/parallelisms/parallel_dims.py create mode 100644 torchtitan/parallelisms/pipeline_llama.py diff --git a/README.md b/README.md index 56785112..e762a492 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,10 @@ Our guiding principles when building `torchtitan`: You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first: * [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code -* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model +* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model +* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model * [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints +* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques * [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants) ## Pre-Release Updates: @@ -48,7 +50,7 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs ### Coming soon 1. Async checkpointing -2. FP8 support +2. Float8 support 3. Context Parallel 4. 3D Pipeline Parallel 5. 
`torch.compile` support diff --git a/estimation.py b/estimation.py index acf867d5..70fb66cb 100644 --- a/estimation.py +++ b/estimation.py @@ -16,7 +16,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer -from torchtitan.float8_linear import Float8Handler +from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers @@ -124,9 +124,9 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) - # a no-op hander if fp8 is not enabled + # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) - # swap to Float8Linear base on fp8 config + # swap to Float8Linear based on float8 configs float8_handler.convert_to_float8_training(whole_model) # apply PT-D DP/TP parallelisms and activation checkpointing @@ -190,7 +190,7 @@ def loss_fn(pred, labels): lr_schedulers.step() # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) + float8_handler.precompute_float8_dynamic_scale_for_fsdp(model) optimizers.zero_grad() print(f"Peak Memory at iter: {iter_idx}") fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True) diff --git a/torchtitan/float8_linear.py b/torchtitan/float8.py similarity index 95% rename from torchtitan/float8_linear.py rename to torchtitan/float8.py index 494b6046..4dc7122b 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8.py @@ -21,7 +21,7 @@ from torchtitan.parallelisms import ParallelDims -def is_sm90_or_later(): +def _is_sm90_or_later(): # Float8 is only supported on H100+ GPUs return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) @@ -33,7 +33,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): float8_config = job_config.float8 if not float8_config.enable_float8_linear: return - if not is_sm90_or_later(): + if not _is_sm90_or_later(): logger.warning( "Failed to swap to Float8Linear because SM90 or later is not available", ) @@ -42,7 +42,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType except ImportError as e: raise ImportError( - "torchao is not installed. Please install it to use fp8 linear layers." + "torchao is not installed. Please install it to use float8 linear layers." 
) from e # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear @@ -64,7 +64,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = True - # for precompute_fp8_dynamic_scale_for_fsdp + # for precompute_float8_dynamic_scale_for_fsdp self.precompute_scale = ( enable_fsdp_float8_all_gather and float8_config.precompute_float8_dynamic_scale_for_fsdp @@ -103,7 +103,7 @@ def convert_to_float8_training(self, model: nn.Module): f"{self.config.enable_fsdp_float8_all_gather}" ) - def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module): + def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module): if not self.enabled: return diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index f86ccc98..39ab8a07 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -127,7 +127,7 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int: def build_metric_logger( - config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None + job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None ): """ parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'. @@ -135,8 +135,8 @@ def build_metric_logger( intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline parallelism is enabled, without forcing logging from all ranks to capture loss information. """ - dump_dir = config.job.dump_folder - tb_config = config.metrics + dump_dir = job_config.job.dump_folder + tb_config = job_config.metrics save_tb_folder = tb_config.save_tb_folder # since we don't have run id, use current minute as the identifier datetime_str = datetime.now().strftime("%Y%m%d-%H%M") diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index 7188474d..dc06d572 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -4,12 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from dataclasses import dataclass -from functools import cached_property -from torch.distributed.device_mesh import init_device_mesh -from torchtitan.logging import logger -from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama +from torchtitan.parallelisms.parallel_dims import ParallelDims +from torchtitan.parallelisms.parallelize_llama import parallelize_llama +from torchtitan.parallelisms.pipeline_llama import pipeline_llama from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule @@ -28,62 +26,3 @@ "llama2": pipeline_llama, "llama3": pipeline_llama, } - - -@dataclass -class ParallelDims: - dp: int - tp: int - pp: int - world_size: int - enable_loss_parallel: bool - dp_type: str - - def __post_init__(self): - self.dp_type = self.dp_type.lower() - self._validate() - - def _validate(self): - dp, tp, pp = self.dp, self.tp, self.pp - if dp == -1: - self.dp = dp = self.world_size // (tp * pp) - assert dp >= 1, dp - assert tp >= 1, tp - assert pp >= 1, pp - assert ( - dp * tp * pp == self.world_size - ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" - assert self.dp_type in ("fsdp", "ddp") - - def build_mesh(self, device_type): - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True - ): - if d > 1: - dims.append(d) - names.append(name) - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - names = tuple(names) - return init_device_mesh(device_type, dims, mesh_dim_names=names) - - @property - def dp_enabled(self): - return self.dp > 1 - - @property - def tp_enabled(self): - return self.tp > 1 - - @property - def pp_enabled(self): - return self.pp > 1 - - @property - def loss_parallel_enabled(self): - return self.tp > 1 and self.enable_loss_parallel - - @cached_property - def model_parallel_size(self): - return self.tp * self.pp diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py new file mode 100644 index 00000000..22c114ed --- /dev/null +++ b/torchtitan/parallelisms/parallel_dims.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from functools import cached_property + +from torch.distributed.device_mesh import init_device_mesh +from torchtitan.logging import logger + + +@dataclass +class ParallelDims: + dp: int + tp: int + pp: int + world_size: int + enable_loss_parallel: bool + dp_type: str + + def __post_init__(self): + self.dp_type = self.dp_type.lower() + self._validate() + + def _validate(self): + dp, tp, pp = self.dp, self.tp, self.pp + if dp == -1: + self.dp = dp = self.world_size // (tp * pp) + assert dp >= 1, dp + assert tp >= 1, tp + assert pp >= 1, pp + assert ( + dp * tp * pp == self.world_size + ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + assert self.dp_type in ("fsdp", "ddp") + + def build_mesh(self, device_type): + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True + ): + if d > 1: + dims.append(d) + names.append(name) + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + names = tuple(names) + return init_device_mesh(device_type, dims, mesh_dim_names=names) + + @property + def dp_enabled(self): + return self.dp > 1 + + @property + def tp_enabled(self): + return self.tp > 1 + + @property + def pp_enabled(self): + return self.pp > 1 + + @property + def loss_parallel_enabled(self): + return self.tp > 1 and self.enable_loss_parallel + + @cached_property + def model_parallel_size(self): + return self.tp * self.pp diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index a4b69344..03540552 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -4,25 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# This file applies the PT-D parallelisms and various training techniques (e.g. -# activation checkpointing and compile) to the Llama model. +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. 
-import copy from collections import defaultdict -from typing import Tuple, TYPE_CHECKING, Union import torch import torch.nn as nn from torch.distributed import DeviceMesh - from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy - from torch.distributed._composable.replicate import replicate from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, ) -from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -33,307 +28,74 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import logger -from torchtitan.models.llama.model import ModelArgs -from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank - -if TYPE_CHECKING: - from torchtitan.parallelisms import ParallelDims - - -DeviceType = Union[int, str, torch.device] - -# for selective AC -no_recompute_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, -} - - -def checkpoint_wrapper(module: torch.nn.Module, ac_config): - valid_ac_modes = ("full", "selective") - if ac_config.mode not in valid_ac_modes: - raise ValueError( - f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" - ) - - if ac_config.mode == "full": - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - - assert ac_config.mode == "selective", f"{ac_config.mode}" - use_op_sac = ac_config.selective_ac_option == "op" - use_layer_sac = ac_config.selective_ac_option.isdigit() - if not use_op_sac and not use_layer_sac: - raise ValueError( - f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" - f"Valid options: 'op' or a positive int representing layer frequency" - ) - if use_op_sac: - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) - - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in no_recompute_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) - - return _custom_policy - - def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) +from torchtitan.parallelisms.parallel_dims import ParallelDims - return ptd_checkpoint_wrapper( - module, - context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, - ) - elif use_layer_sac: - # Checkpoint every `ac_freq` of the modules passed to this function - ac_freq = int(ac_config.selective_ac_option) - if ac_freq <= 0: - raise ValueError( - f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}" - ) - ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) - ptd_checkpoint_wrapper._count += 1 - if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - else: - return module - - -def get_tp_parallel_strategy_for_transformer_block( - enable_float8: bool, -) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]: - """Get the parallel strategy for the transformer model. - This function handles the special case of using float8 with tensor parallelism. - """ - if enable_float8: - # TODO(vkuzo): once float8 configuration supports delayed - # scaling, add a check here to enforce supported float8 all-gather - # configurations - # TODO(vkuzo): add the items below to __init__.py of torchao.float8, - # and import from there - from torchao.float8.float8_tensor_parallel import ( - Float8ColwiseParallel, - Float8RowwiseParallel, - PrepareFloat8ModuleInput, - ) - - return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput - return RowwiseParallel, ColwiseParallel, PrepareModuleInput - - -def pipeline_llama( +def parallelize_llama( model: nn.Module, - pp_mesh: DeviceMesh, - parallel_dims: "ParallelDims", - job_config: JobConfig, - device: DeviceType, - model_config: ModelArgs, -): - split_mode = job_config.experimental.pipeline_parallel_split_mode - valid_split_modes = ("manual", "tracer") - if split_mode not in valid_split_modes: - raise ValueError( - f"Invalid split mode: {split_mode}. 
Valid split modes: {valid_split_modes}" - ) - if split_mode == "manual": - return pipeline_llama_manual( - model, pp_mesh, parallel_dims, job_config, device, model_config - ) - elif split_mode == "tracer": - return pipeline_llama_tracer( - model, pp_mesh, parallel_dims, job_config, device, model_config - ) - - -def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"): - """Get meta tensors with the right input shapes used for tracing""" - tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) - tokens = torch.randint( - model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device - ) - return (tokens,) - - -def _mixed_precision_dtype( - job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32 -) -> torch.dtype: - """Get the mixed precision dtype if FSDP is enabled, otherwise return the default""" - mp_arg = job_config.training.mixed_precision_param - return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default - - -def pipeline_llama_manual( - whole_model: nn.Module, - pp_mesh: DeviceMesh, - parallel_dims: "ParallelDims", + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, job_config: JobConfig, - device: DeviceType, - model_config: ModelArgs, ): """ - This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. - - It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. - The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD - parallelism. + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. """ - pp_rank = pp_mesh.get_local_rank() - pp_size = pp_mesh.size() - microbatches = ( - job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp - ) - splits = job_config.experimental.pipeline_parallel_split_points - - def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False): - model = copy.deepcopy(whole_model) - if not is_first: - model.tok_embeddings = None - - drop_layers = start_layer is not None - for name in list(model.layers.keys()): - # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) - if f"layers.{name}" == start_layer: - drop_layers = False - if f"layers.{name}" == stop_layer: - drop_layers = True - if drop_layers: - del model.layers[name] - - if not is_last: - model.norm = None - model.output = None - - # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and - # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the - # layers of the model that map to this stage, not the whole model. 
- mp_dtype = _mixed_precision_dtype(job_config, parallel_dims) - batch_size = job_config.training.batch_size - local_seq_len = int(job_config.training.seq_len // parallel_dims.tp) - layers_io_shape = (batch_size, local_seq_len, model_config.dim) - output_layer_shape = ( - batch_size, - job_config.training.seq_len, - model_config.vocab_size, - ) - if is_first: - (input,) = _llama_trace_input(job_config, model_config, device=device) - else: - # later layers (assume all start w/ a transformer layer) - input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) - if is_last: - output = torch.rand(output_layer_shape, dtype=torch.float32, device=device) - else: - # earlier layers (assume all end in a transformer layer) - output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) - - model.to_empty(device=device) - stage = PipelineStage( + if parallel_dims.tp_enabled: + if ( + job_config.experimental.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") + model = apply_tp( model, - stage_idx, - num_stages, - device, - input_args=input.chunk(microbatches)[0], - output_args=output.chunk(microbatches)[0], - group=pp_mesh.get_group("pp"), - ) - return stage, model - - num_stages = len(splits) + 1 - stage_idx = pp_rank - - stages = [] - models = [] - for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"): - start_layer = splits[stage_idx - 1] if stage_idx > 0 else None - stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None - stage, model_chunk = _build_stage( - stage_idx, - start_layer, - stop_layer, - is_first=stage_idx == 0, - is_last=stage_idx == num_stages - 1, - ) - logger.info( - f"PP rank {pp_rank} is building stage_idx {stage_idx}" - f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}" + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=job_config.float8.enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, ) - stages.append(stage) - models.append(model_chunk) - return stages, models + if job_config.activation_checkpoint.mode != "none": + model = apply_ac(model, job_config.activation_checkpoint) -def pipeline_llama_tracer( - model: nn.Module, - pp_mesh: DeviceMesh, - parallel_dims: "ParallelDims", - job_config: JobConfig, - device: DeviceType, - model_config: ModelArgs, -): - if job_config.model.norm_type == "fused_rmsnorm": - # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr - # invocation stride in strict mode from `if dy.stride(-1) != 1:` in - # fused_rmsnorm - raise NotImplementedError( - "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm." - ) - if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32: - raise NotImplementedError( - "Pipeline tracer does not work with FSDP mixed precision yet. " - "To work around, set mixed_precision_param to float32." - ) + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.training.compile: + if job_config.model.norm_type == "fused_rmsnorm": + raise NotImplementedError( + "fused_rmsnorm is not compatible with torch.compile yet. " + "Please use rmsnorm or layernorm." 
+ ) + model = apply_compile(model) - pp_rank = pp_mesh.get_local_rank() - pp_size = pp_mesh.size() - microbatches = ( - job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp - ) - (input,) = _llama_trace_input(job_config, model_config, device=device) - stage_idx = pp_rank - split_spec = { - layer_name: SplitPoint.BEGINNING - for layer_name in job_config.experimental.pipeline_parallel_split_points - } - num_stages = len(split_spec) + 1 - pipe = pipeline( - model, - mb_args=(input.chunk(microbatches)[0],), - split_spec=split_spec, - ) + if parallel_dims.dp_enabled: + if parallel_dims.dp_type == "fsdp": + dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh + assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - stages = [] - models = [] - for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"): - models.append(pipe.get_stage_module(stage_idx)) - stages.append( - pipe.build_stage( - stage_idx, - device=device, - group=pp_mesh.get_group(), + model = apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[ + job_config.training.mixed_precision_reduce + ], + pp_enabled=parallel_dims.pp_enabled, ) - ) - return (stages, models) + else: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + model = apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) + + return model def apply_tp( @@ -367,11 +129,27 @@ def apply_tp( # Parallel styles used for transformer block linear weights and their # inputs may be different for float8 linears - ( - rowwise_parallel_weight, - colwise_parallel_weight, - prepare_module_input, - ) = get_tp_parallel_strategy_for_transformer_block(enable_float8) + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel @@ -384,18 +162,18 @@ def apply_tp( input_layouts=(Shard(1), None), desired_input_layouts=(Replicate(), None), ), - "attention.wq": colwise_parallel_weight(), - "attention.wk": colwise_parallel_weight(), - "attention.wv": colwise_parallel_weight(), - "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), "feed_forward": prepare_module_input( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), ), - "feed_forward.w1": colwise_parallel_weight(), - "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)), - "feed_forward.w3": colwise_parallel_weight(), + "feed_forward.w1": colwise_parallel(), + 
"feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), } parallelize_module( @@ -404,35 +182,97 @@ def apply_tp( parallelize_plan=layer_plan, ) - # updates expressly for async tensor parallel if enable_async_tp: from torch.distributed._symmetric_memory import enable_symm_mem_for_group + # TODO: remove cache_size_limit adjustment after 2D compile is fixed torch._dynamo.config.cache_size_limit = 10000 - logger.info( - "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP" - ) torch._inductor.config._micro_pipeline_tp = True enable_symm_mem_for_group(tp_mesh.get_group().group_name) - if not job_config.training.compile: - logger.warning( - "Async TP requires compilation...auto enabling compile = True for this job to resolve." - ) - job_config.training.compile = True - logger.info( - f"Applied {'Async ' if enable_async_tp else ''}" + f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" "Tensor Parallelism to the model" ) return model -def apply_ac(model: nn.Module, ac_config: JobConfig): +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, +} + + +def _apply_ac_to_transformer_block(module: nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): """Apply activation checkpointing to the model.""" for layer_id, transformer_block in model.layers.named_children(): - transformer_block = checkpoint_wrapper(transformer_block, ac_config) + transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config) model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") @@ -442,13 +282,12 @@ def apply_ac(model: nn.Module, ac_config: JobConfig): def apply_compile(model: nn.Module): """Apply torch.compile to each transformer block.""" - # the following flag can be used to to accelarate per-block compilation + # the following flag can be used to to accelarate per-TransformerBlock compilation # TODO(bdhirsh): turning it off because it's currently not working with 2D # TODO(anijain): remove it after it's enabled in pytorch by default # torch._dynamo.config.inline_inbuilt_nn_modules = True for layer_id, transformer_block in model.layers.named_children(): - # turn on per-transformer block compile after AC wrapping and before FSDP transformer_block = torch.compile(transformer_block, fullgraph=True) model.layers.register_module(layer_id, transformer_block) @@ -518,63 +357,3 @@ def apply_ddp( logger.info("Applied DDP to the model") return model - - -def parallelize_llama( - model: nn.Module, - world_mesh: DeviceMesh, - parallel_dims: "ParallelDims", - job_config: JobConfig, -): - """ - Apply tensor parallelism, activation checkpointing, torch.compile, and data - parallelism to the model. - - NOTE: The passed-in model preferably should be on meta device. Otherwise, - the model must fit on GPU or CPU memory. 
- """ - - if parallel_dims.tp_enabled: - model = apply_tp( - model, - world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, - enable_float8=job_config.float8.enable_float8_linear, - enable_async_tp=job_config.experimental.enable_async_tensor_parallel, - ) - - if job_config.activation_checkpoint.mode != "none": - model = apply_ac(model, job_config.activation_checkpoint) - - if job_config.training.compile: - if job_config.model.norm_type == "fused_rmsnorm": - raise NotImplementedError( - "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm." - ) - model = apply_compile(model) - - if parallel_dims.dp_enabled: - if parallel_dims.dp_type == "fsdp": - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - - model = apply_fsdp( - model, - dp_mesh, - param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[ - job_config.training.mixed_precision_reduce - ], - pp_enabled=parallel_dims.pp_enabled, - ) - else: - if world_mesh.ndim > 1: - raise RuntimeError("DDP has not supported > 1D parallelism.") - model = apply_ddp( - model, - world_mesh, - enable_compile=job_config.training.compile, - enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, - ) - - return model diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py new file mode 100644 index 00000000..fa093b6e --- /dev/null +++ b/torchtitan/parallelisms/pipeline_llama.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D pipeline parallelism to the Llama model. + +import copy +from typing import Union + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.logging import logger +from torchtitan.models.llama.model import ModelArgs +from torchtitan.parallelisms.parallel_dims import ParallelDims +from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank + + +DeviceType = Union[int, str, torch.device] + + +def pipeline_llama( + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, +): + split_mode = job_config.experimental.pipeline_parallel_split_mode + valid_split_modes = ("manual", "tracer") + if split_mode not in valid_split_modes: + raise ValueError( + f"Invalid split mode: {split_mode}. 
Valid split modes: {valid_split_modes}" + ) + if split_mode == "manual": + return pipeline_llama_manual( + model, pp_mesh, parallel_dims, job_config, device, model_config + ) + elif split_mode == "tracer": + return pipeline_llama_tracer( + model, pp_mesh, parallel_dims, job_config, device, model_config + ) + + +def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"): + """Get meta tensors with the right input shapes used for tracing""" + tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) + tokens = torch.randint( + model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device + ) + return (tokens,) + + +def _mixed_precision_dtype( + job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32 +) -> torch.dtype: + """Get the mixed precision dtype if FSDP is enabled, otherwise return the default""" + mp_arg = job_config.training.mixed_precision_param + return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default + + +def pipeline_llama_manual( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, +): + """ + This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. + + It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + + The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD + parallelism. + """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + microbatches = ( + job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp + ) + splits = job_config.experimental.pipeline_parallel_split_points + + def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False): + model = copy.deepcopy(whole_model) + if not is_first: + model.tok_embeddings = None + + drop_layers = start_layer is not None + for name in list(model.layers.keys()): + # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) + if f"layers.{name}" == start_layer: + drop_layers = False + if f"layers.{name}" == stop_layer: + drop_layers = True + if drop_layers: + del model.layers[name] + + if not is_last: + model.norm = None + model.output = None + + # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and + # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the + # layers of the model that map to this stage, not the whole model. 
+ mp_dtype = _mixed_precision_dtype(job_config, parallel_dims) + batch_size = job_config.training.batch_size + local_seq_len = int(job_config.training.seq_len // parallel_dims.tp) + layers_io_shape = (batch_size, local_seq_len, model_config.dim) + output_layer_shape = ( + batch_size, + job_config.training.seq_len, + model_config.vocab_size, + ) + if is_first: + (input,) = _llama_trace_input(job_config, model_config, device=device) + else: + # later layers (assume all start w/ a transformer layer) + input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) + + if is_last: + output = torch.rand(output_layer_shape, dtype=torch.float32, device=device) + else: + # earlier layers (assume all end in a transformer layer) + output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) + + model.to_empty(device=device) + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + input_args=input.chunk(microbatches)[0], + output_args=output.chunk(microbatches)[0], + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(splits) + 1 + stage_idx = pp_rank + + stages = [] + models = [] + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"): + start_layer = splits[stage_idx - 1] if stage_idx > 0 else None + stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None + stage, model_chunk = _build_stage( + stage_idx, + start_layer, + stop_layer, + is_first=stage_idx == 0, + is_last=stage_idx == num_stages - 1, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx}" + f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}" + ) + stages.append(stage) + models.append(model_chunk) + return stages, models + + +def pipeline_llama_tracer( + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: ModelArgs, +): + if job_config.model.norm_type == "fused_rmsnorm": + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr + # invocation stride in strict mode from `if dy.stride(-1) != 1:` in + # fused_rmsnorm + raise NotImplementedError( + "fused_rmsnorm is not compatible with Pipeline Tracer yet. " + "Please use rmsnorm or layernorm." + ) + if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32: + raise NotImplementedError( + "Pipeline tracer does not work with FSDP mixed precision yet. " + "To work around, set mixed_precision_param to float32." 
+ ) + + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + microbatches = ( + job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp + ) + (input,) = _llama_trace_input(job_config, model_config, device=device) + stage_idx = pp_rank + split_spec = { + layer_name: SplitPoint.BEGINNING + for layer_name in job_config.experimental.pipeline_parallel_split_points + } + num_stages = len(split_spec) + 1 + pipe = pipeline( + model, + mb_args=(input.chunk(microbatches)[0],), + split_spec=split_spec, + ) + + stages = [] + models = [] + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"): + models.append(pipe.get_stage_module(stage_idx)) + stages.append( + pipe.build_stage( + stage_idx, + device=device, + group=pp_mesh.get_group(), + ) + ) + return (stages, models) diff --git a/train.py b/train.py index 615ed4e3..5c62debf 100644 --- a/train.py +++ b/train.py @@ -10,12 +10,13 @@ from datetime import timedelta import torch -import torchtitan.utils as utils from torch.distributed.elastic.multiprocessing.errors import record + +from torchtitan import utils from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_hf_data_loader, build_tokenizer -from torchtitan.float8_linear import Float8Handler +from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config @@ -116,9 +117,9 @@ def main(job_config: JobConfig): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) - # a no-op hander if fp8 is not enabled + # a no-op hander if float8 is not enabled float8_handler = Float8Handler(job_config, parallel_dims) - # swap to Float8Linear base on fp8 config + # swap to Float8Linear based on float8 configs float8_handler.convert_to_float8_training(whole_model) # log model size @@ -315,7 +316,7 @@ def loss_fn(pred, labels): # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) + float8_handler.precompute_float8_dynamic_scale_for_fsdp(model) losses_since_last_log.append(loss)
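
The sketch below is not part of the patch; it only illustrates how the pieces touched by these two commits compose after the refactor: ParallelDims now lives in parallel_dims.py, parallelize_llama keeps the SPMD parallelisms (TP, activation checkpointing, compile, FSDP/DDP), pipeline_llama moves to its own file, and the Float8Handler method is renamed to precompute_float8_dynamic_scale_for_fsdp. Function signatures are taken from the diffs above; the parallelism degrees and the helper name setup_model are illustrative assumptions, and the actual wiring in train.py may differ.

import torch.nn as nn

from torchtitan.float8 import Float8Handler
from torchtitan.parallelisms import ParallelDims, parallelize_llama, pipeline_llama


def setup_model(whole_model: nn.Module, model_config, job_config, device):
    # Illustrative degrees only: 8 GPUs split as pp=2 x dp=2 x tp=2.
    parallel_dims = ParallelDims(
        dp=-1,  # -1 is inferred as world_size // (tp * pp) in _validate()
        tp=2,
        pp=2,
        world_size=8,
        enable_loss_parallel=True,
        dp_type="fsdp",
    )
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    # A no-op handler unless float8 is enabled in job_config (see torchtitan/float8.py).
    float8_handler = Float8Handler(job_config, parallel_dims)
    float8_handler.convert_to_float8_training(whole_model)

    if parallel_dims.pp_enabled:
        # pipeline_llama (new file) splits the model into per-rank stage chunks,
        # then parallelize_llama applies the SPMD parallelisms to each chunk.
        stages, model_parts = pipeline_llama(
            whole_model,
            world_mesh["pp"],
            parallel_dims,
            job_config,
            device,
            model_config,
        )
        model_parts = [
            parallelize_llama(part, world_mesh, parallel_dims, job_config)
            for part in model_parts
        ]
        return stages, model_parts

    # Single-stage path: TP, activation checkpointing, compile, then FSDP or DDP.
    model = parallelize_llama(whole_model, world_mesh, parallel_dims, job_config)
    return None, [model]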