diff --git a/README.md b/README.md index 53a0d149..eb6aa76a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # torchtitan -`torchtitan` is currently in a pre-release state and under extensive development. +`torchtitan` is currently in a pre-release state and under extensive development. Currently we showcase pre-training **Llama 3.1**, **Llama 3**, and **Llama 2** LLMs of various sizes from scratch. To use the latest features of `torchtitan`, we recommend latest PyTorch nightly. `torchtitan` is a proof-of-concept for Large-scale LLM training using native PyTorch. It is (and will continue to be) a repo to showcase PyTorch's latest distributed training features in a clean, minimal codebase. torchtitan is complementary to and not a replacement for any of the great large-scale LLM training codebases such as Megatron, Megablocks, LLM Foundry, Deepspeed, etc. Instead, we hope that the features showcased in torchtitan will be adopted by these codebases quickly. torchtitan is unlikely to ever grow a large community around it. @@ -26,34 +26,30 @@ You may want to see how the model is defined or how parallelism techniques are a * [torchtitan/parallelisms/pipeline_llama.py](torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model * [torchtitan/checkpoint.py](torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints * [torchtitan/float8.py](torchtitan/float8.py) - utils for applying Float8 techniques -* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants) - -## Pre-Release Updates: -#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development. -Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly). +* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama 2 and Llama 3 variants) ### Key features available -1. [FSDP2 with per param sharding](docs/fsdp.md) -2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including async TP) +1. [FSDP2](docs/fsdp.md) with per param sharding +2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including [async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)) 3. Selective layer and operator activation checkpointing 4. Distributed checkpointing (including async checkpointing) 5. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) -6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via TensorBoard -7. Learning rate scheduler, meta-init, optional Fused RMSNorm into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning -8. [Float8 support](docs/float8.md) +6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via [TensorBoard](#tensorboard) +7. Learning rate scheduler, meta-init, optional Fused RMSNorm +8. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md)) 9. `torch.compile` support -10. All options easily configured via [toml files](train_configs/) -11. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly +10. DDP and HSDP +11. All options easily configured via [toml files](train_configs/) +12. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning -We report our [Performance](docs/performance.md) verified on 64 A100 GPUs +We report our [Performance](docs/performance.md) verified on 64/128 GPUs. ### Coming soon -1. Context Parallel -2. Pipeline Parallel (and 3D parallellism) -3. HSDP +- Pipeline Parallel (and 3D parallellism) +- Context Parallel ## Installation @@ -74,10 +70,10 @@ Once you have confirmed access, you can run the following command to download th ```bash # Get your HF token from https://huggingface.co/settings/tokens -# llama3 or 3.1 tokenizer.model +# Llama 3 or 3.1 tokenizer.model python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=... -# llama2 tokenizer.model +# Llama 2 tokenizer.model python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=... ``` diff --git a/estimation.py b/estimation.py index 13ccd4c1..f58907c6 100644 --- a/estimation.py +++ b/estimation.py @@ -64,12 +64,12 @@ def estimate_memory(job_config: JobConfig): job_config.experimental.enable_compiled_autograd = False parallel_dims = ParallelDims( - dp=job_config.training.data_parallel_degree, + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, - dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") diff --git a/test/test_fused_rms_norm.py b/test/test_fused_rms_norm.py index 9bd7e373..d5c353c2 100644 --- a/test/test_fused_rms_norm.py +++ b/test/test_fused_rms_norm.py @@ -11,7 +11,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test_runner.py b/test_runner.py index d7633e2e..859fd5ed 100755 --- a/test_runner.py +++ b/test_runner.py @@ -157,7 +157,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", - "--training.data_parallel_degree 1", + "--training.data_parallel_shard_degree 1", ], ], "PP 1D test 1f1b", @@ -172,7 +172,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", - "--training.data_parallel_degree 1", + "--training.data_parallel_shard_degree 1", ], ], "PP 1D test gpipe", @@ -187,7 +187,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule 1f1b", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", ], ], "PP+DP 1f1b 2D test", @@ -201,7 +201,7 @@ def build_test_list(): "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", "--experimental.pipeline_parallel_schedule gpipe", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", ], ], "PP+DP gpipe 2D test", @@ -227,7 +227,7 @@ def build_test_list(): "--checkpoint.enable_checkpoint", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", "--training.tensor_parallel_degree 2", ], [ @@ -235,7 +235,7 @@ def build_test_list(): "--checkpoint.enable_checkpoint", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", "--training.tensor_parallel_degree 2", ], ], @@ -249,7 +249,7 @@ def build_test_list(): [ "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.4", - "--training.data_parallel_degree 2", + "--training.data_parallel_shard_degree 2", "--training.tensor_parallel_degree 2", "--float8.enable_float8_linear", "--float8.enable_fsdp_float8_all_gather", @@ -302,13 +302,37 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.data_parallel_type ddp", + "--training.data_parallel_shard_degree=1", + "--training.data_parallel_replicate_degree=4", ] ], "DDP", "ddp", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--training.data_parallel_replicate_degree=2", + ] + ], + "HSDP", + "hsdp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--training.data_parallel_shard_degree=2", + "--training.data_parallel_replicate_degree=2", + "--training.tensor_parallel_degree=2", + ] + ], + "HSDP+TP", + "hsdp+tp", + ngpu=8, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index b71419c6..266f689c 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -84,7 +84,7 @@ class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model - def state_dict(self) -> None: + def state_dict(self) -> Dict[str, Any]: return { k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() } @@ -107,7 +107,7 @@ def __init__( self.model = [model] if isinstance(model, nn.Module) else model self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim - def state_dict(self) -> None: + def state_dict(self) -> Dict[str, Any]: func = functools.partial( get_optimizer_state_dict, options=StateDictOptions(flatten_optimizer_state_dict=True), diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3ba1d102..67c82d53 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -224,10 +224,34 @@ def __init__(self): help="How many train steps to run", ) self.parser.add_argument( - "--training.data_parallel_degree", + "--training.data_parallel_replicate_degree", + type=int, + default=1, + help=""" + The `data_parallel_replicate_degree` argument specifies the degree of + data parallelism for weight replication. When this value is greater + than 1, weights will be replicated across `data_parallel_replicate_degree` + ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism + method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the + parallelism method used is DDP (Distributed Data Parallelism). + 1 means disabled.""", + ) + self.parser.add_argument( + "--training.data_parallel_shard_degree", type=int, default=-1, - help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.", + help=""" + The `data_parallel_shard_degree` argument specifies the degree of data + parallelism for weight sharding. When this value is greater than 1, weights + will be sharded across `data_parallel_shard_degree` ranks. If + `data_parallel_replicate_degree` is also greater than 1, the parallelism + method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the + parallelism method used is FSDP (Fully Sharded Data Parallelism). + + -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that + only one of `data_parallel_replicate_degree` and `data_parallel_shard_degree` + can be negative. + 1 means disabled.""", ) self.parser.add_argument( "--training.tensor_parallel_degree", @@ -297,12 +321,6 @@ def __init__(self): The default value will be the number of pipeline stages, if unspecified. """, ) - self.parser.add_argument( - "--training.data_parallel_type", - type=str, - default="fsdp", - help="Data parallelism type. TorchTitan currently supports FSDP and DDP.", - ) self.parser.add_argument( "--experimental.enable_compiled_autograd", action="store_true", diff --git a/torchtitan/float8.py b/torchtitan/float8.py index 244fba47..4e5fa8f5 100644 --- a/torchtitan/float8.py +++ b/torchtitan/float8.py @@ -49,8 +49,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( - parallel_dims.dp_enabled - and parallel_dims.dp_type == "fsdp" + parallel_dims.dp_shard_enabled and float8_config.enable_fsdp_float8_all_gather ) scaling_type_input = ScalingType(float8_config.scaling_type_input) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 2060519a..7f102a80 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -29,7 +29,6 @@ class ModelArgs: norm_eps: float = 1e-5 rope_theta: float = 10000 - max_batch_size: int = 32 max_seq_len: int = 2048 # If `True`, then each transformer block init uses its layer ID, and if # `False`, each uses the total number of transformer blocks diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 22c114ed..2e2aacc7 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -13,45 +13,78 @@ @dataclass class ParallelDims: - dp: int + dp_replicate: int + dp_shard: int tp: int pp: int world_size: int enable_loss_parallel: bool - dp_type: str def __post_init__(self): - self.dp_type = self.dp_type.lower() self._validate() def _validate(self): - dp, tp, pp = self.dp, self.tp, self.pp - if dp == -1: - self.dp = dp = self.world_size // (tp * pp) - assert dp >= 1, dp + dp_replicate, dp_shard, tp, pp = ( + self.dp_replicate, + self.dp_shard, + self.tp, + self.pp, + ) + for d in (dp_replicate, tp, pp): + assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" + assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." + + dp = dp_replicate * dp_shard + if dp < 0: + dp = self.world_size // (tp * pp) + self.dp_shard = dp_shard = dp // dp_replicate + + assert dp_replicate >= 1 + assert dp_shard >= 1 assert tp >= 1, tp assert pp >= 1, pp - assert ( - dp * tp * pp == self.world_size - ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" - assert self.dp_type in ("fsdp", "ddp") + assert dp_replicate * dp_shard * tp * pp == self.world_size, ( + f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " + f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + ) def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True + [self.pp, self.dp_replicate, self.dp_shard, self.tp], + ["pp", "dp_replicate", "dp_shard", "tp"], + strict=True, ): if d > 1: dims.append(d) - names.append(name) + if (name == "dp_replicate" and self.dp_shard == 1) or ( + name == "dp_shard" and self.dp_replicate == 1 + ): + names.append("dp") + else: + names.append(name) + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) - return init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + # Create all the submesh here to ensure all required process groups are + # initialized + if self.dp_replicate > 1 and self.dp_shard > 1: + mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp") + return mesh @property def dp_enabled(self): - return self.dp > 1 + return self.dp_replicate > 1 or self.dp_shard > 1 + + @property + def dp_replicate_enabled(self): + return self.dp_replicate > 1 + + @property + def dp_shard_enabled(self): + return self.dp_shard > 1 @property def tp_enabled(self): diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index aa07f25f..fc26703d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -73,9 +73,11 @@ def parallelize_llama( apply_compile(model) if parallel_dims.dp_enabled: - if parallel_dims.dp_type == "fsdp": - dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh - assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names + if parallel_dims.dp_shard_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mesh = world_mesh["dp_replicate", "dp_shard"] + else: + dp_mesh = world_mesh["dp"] apply_fsdp( model, @@ -87,6 +89,10 @@ def parallelize_llama( tp_enabled=parallel_dims.tp_enabled, pp_enabled=parallel_dims.pp_enabled, ) + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") else: if world_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") @@ -322,8 +328,6 @@ def apply_fsdp( ) fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) - logger.info("Applied FSDP to the model") - def apply_ddp( model: nn.Module, diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py index 7bce2fe6..7e12aea6 100644 --- a/torchtitan/parallelisms/pipeline_llama.py +++ b/torchtitan/parallelisms/pipeline_llama.py @@ -104,9 +104,11 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal model.norm = None model.output = None - # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and - # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the - # layers of the model that map to this stage, not the whole model. + # Note: these tensors are only here as metadata hints, so pipelining runtime knows what size buffer to allocate. + # these tensors should be on meta device, adn the model should also. It will be allocated on device after + # applying all other parallelisms. + + # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can avoid specifying input/output shapes mp_dtype = _mixed_precision_dtype(job_config, parallel_dims) batch_size = job_config.training.batch_size local_seq_len = int(job_config.training.seq_len // parallel_dims.tp) @@ -117,18 +119,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal model_config.vocab_size, ) if is_first: - (input,) = _llama_trace_input(job_config, model_config, device=device) + (input,) = _llama_trace_input(job_config, model_config, device="meta") else: # later layers (assume all start w/ a transformer layer) - input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) + input = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta") if is_last: - output = torch.rand(output_layer_shape, dtype=torch.float32, device=device) + output = torch.rand(output_layer_shape, dtype=torch.float32, device="meta") else: # earlier layers (assume all end in a transformer layer) - output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device) + output = torch.rand(layers_io_shape, dtype=mp_dtype, device="meta") - model.to_empty(device=device) stage = PipelineStage( model, stage_idx, diff --git a/train.py b/train.py index ffea00a9..d1973b6d 100644 --- a/train.py +++ b/train.py @@ -59,12 +59,12 @@ def main(job_config: JobConfig): # init distributed world_size = int(os.environ["WORLD_SIZE"]) parallel_dims = ParallelDims( - dp=job_config.training.data_parallel_degree, + dp_shard=job_config.training.data_parallel_shard_degree, + dp_replicate=job_config.training.data_parallel_replicate_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, - dp_type=job_config.training.data_parallel_type, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index af547214..bb3cd353 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -35,7 +35,8 @@ seq_len = 2048 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 10 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index df2f6bb3..3230b208 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -31,7 +31,8 @@ seq_len = 4096 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 compile = false dataset = "c4" diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 354ebe11..e7c920c6 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -31,7 +31,8 @@ seq_len = 4096 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = false dataset = "c4" diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index e2b0e78d..5ffaaeca 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -30,7 +30,8 @@ seq_len = 2048 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B compile = false dataset = "c4" diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml index 1a83301f..c7723ef3 100644 --- a/train_configs/llama3_405b.toml +++ b/train_configs/llama3_405b.toml @@ -31,7 +31,8 @@ seq_len = 8192 warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 3000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = true dataset = "c4" diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 470149a5..fb6d5f50 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -31,7 +31,8 @@ seq_len = 8192 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 8 # 8-way TP compile = false dataset = "c4" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 3d0c5160..e0c5bd03 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -31,7 +31,8 @@ seq_len = 8192 warmup_steps = 200 # lr scheduler warm up max_norm = 1.0 # grad norm clipping steps = 1000 -data_parallel_degree = -1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 tensor_parallel_degree = 1 compile = false dataset = "c4"