diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 9a086830..e0f6d701 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -482,6 +482,32 @@ def __init__(self):
             """,
         )
+        # commdebugmode configs
+        self.parser.add_argument(
+            "--comm_debug.enable_comm_debug_mode",
+            default=False,
+            action="store_true",
+            help="Whether to enable CommDebugMode, should be used only on first step",
+        )
+        self.parser.add_argument(
+            "--comm_debug.dump_file",
+            type=str,
+            default="torchtitan_comm_debug_dump.txt",
+            help="Which file to dump CommDebugMode's output to",
+        )
+        self.parser.add_argument(
+            "--comm_debug.dump_json",
+            type=str,
+            default="torchtitan_comm_debug_log.json",
+            help="Which file to dump CommDebugMode's json to",
+        )
+        self.parser.add_argument(
+            "--comm_debug.noise_level",
+            type=int,
+            default=2,
+            help="Sets noise level for CommDebugMode's output, controls how much info is displayed",
+        )
+
         # communications library settings
         self.parser.add_argument(
             "--comm.init_timeout_seconds",
diff --git a/train.py b/train.py
index b7eee302..0494364e 100644
--- a/train.py
+++ b/train.py
@@ -8,7 +8,6 @@
 import gc
 import os
 import time
-
 from dataclasses import dataclass, field
 from datetime import timedelta
 from io import BytesIO
@@ -20,10 +19,10 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
-
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader, create_tokenizer
@@ -138,9 +137,14 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])
 
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+def get_train_context(
+    enable_loss_parallel: bool,
+    enable_compiled_autograd: bool,
+    enable_comm_debug_mode: bool,
+):
     @contextlib.contextmanager
     def context():
+        context_managers = {}
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(loss_parallel())
@@ -148,8 +152,11 @@ def context():
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+            if enable_comm_debug_mode:
+                comm_mode = stack.enter_context(CommDebugMode())
+                context_managers["comm_mode"] = comm_mode
 
-            yield
+            yield context_managers
 
     return context
@@ -214,6 +221,7 @@ def main(job_config: JobConfig):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        job_config.comm_debug.enable_comm_debug_mode,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -381,7 +389,7 @@ def loss_fn(pred, labels):
             # pipeline parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with train_context():
+            with train_context() as train_contexts:
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
@@ -390,6 +398,20 @@ def loss_fn(pred, labels):
                 else:
                     pp_schedule.step()
 
+                if (
+                    job_config.comm_debug.enable_comm_debug_mode
+                    and train_state.step == 1
+                ):
+                    comm_mode = train_contexts["comm_mode"]
+                    comm_mode.log_comm_debug_tracing_table_to_file(
+                        file_name=job_config.comm_debug.dump_file,
+                        noise_level=job_config.comm_debug.noise_level,
+                    )
+                    comm_mode.generate_json_dump(
+                        file_name=job_config.comm_debug.dump_json,
+                        noise_level=job_config.comm_debug.noise_level,
+                    )
+
             # accumulate losses across pipeline microbatches
             loss = (
                 torch.mean(torch.stack(losses))
@@ -398,7 +420,7 @@ def loss_fn(pred, labels):
             )
         else:
             # Non-PP forward / backward
-            with train_context():
+            with train_context() as train_contexts:
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
@@ -406,6 +428,20 @@ def loss_fn(pred, labels):
                 del pred
                 loss.backward()
 
+                if (
+                    job_config.comm_debug.enable_comm_debug_mode
+                    and train_state.step == 1
+                ):
+                    comm_mode = train_contexts["comm_mode"]
+                    comm_mode.log_comm_debug_tracing_table_to_file(
+                        file_name=job_config.comm_debug.dump_file,
+                        noise_level=job_config.comm_debug.noise_level,
+                    )
+                    comm_mode.generate_json_dump(
+                        file_name=job_config.comm_debug.dump_json,
+                        noise_level=job_config.comm_debug.noise_level,
+                    )
+
             # clip gradients
             for model in model_parts:
                 torch.nn.utils.clip_grad_norm_(
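
For reference, a minimal standalone sketch of the same CommDebugMode flow the train.py changes add: trace operations inside the context, then dump a tracing table and a JSON log. This is illustrative only and not part of the patch; the single-rank gloo setup, the toy DTensor matmul, and the output file names are assumptions, while the two dump calls mirror the ones used above.

```python
# Illustrative sketch only (not part of this PR): exercise CommDebugMode the way
# train.py now does on step 1 -- trace ops inside the context, then dump results.
# The single-rank gloo setup, toy DTensor matmul, and file names are assumptions.
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
a = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(8, 8), mesh, [Shard(1)])

with CommDebugMode() as comm_mode:
    # On a real multi-rank mesh this sharded matmul triggers collectives; on a
    # single rank the trace is mostly empty, but the logging flow is identical.
    (a @ b).to_local()

# Same dump APIs that train.py calls when --comm_debug.enable_comm_debug_mode is
# set and train_state.step == 1.
comm_mode.log_comm_debug_tracing_table_to_file(
    file_name="toy_comm_debug_dump.txt", noise_level=2
)
comm_mode.generate_json_dump(file_name="toy_comm_debug_log.json", noise_level=2)

dist.destroy_process_group()
```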