Skip to content

Commit

Permalink
[torchtitan][debug] integrated CommDebugMode into TorchTitan
Browse files Browse the repository at this point in the history
ghstack-source-id: fbbc6f0257396b21eea0e40939c832a7afa3490f
Pull Request resolved: #480
  • Loading branch information
sinhaanshul committed Jul 24, 2024
1 parent 0ee573c commit 8c6daf9
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
26 changes: 26 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,32 @@ def __init__(self):
""",
)

# commdebugmode configs
self.parser.add_argument(
"--comm_debug.enable_comm_debug_mode",
default=False,
action="store_true",
help="Whether to enable CommDebugMode, should be used only on first step",
)
self.parser.add_argument(
"--comm_debug.dump_file",
type=str,
default="torchtitan_comm_debug_dump.txt",
help="Which file to dump CommDebugMode's output to",
)
self.parser.add_argument(
"--comm_debug.dump_json",
type=str,
default="torchtitan_comm_debug_log.json",
help="Which file to dump CommDebugMode's json to",
)
self.parser.add_argument(
"--comm_debug.noise_level",
type=int,
default=2,
help="Sets noise level for CommDebugMode's output, controls how much info is displayed",
)

# communications library settings
self.parser.add_argument(
"--comm.init_timeout_seconds",
Expand Down
48 changes: 42 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import gc
import os
import time

from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
Expand All @@ -20,10 +19,10 @@
import torch
import torch.nn.functional as F
from torch.distributed import destroy_process_group
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel

from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_data_loader, create_tokenizer
Expand Down Expand Up @@ -138,18 +137,26 @@ def zero_grad(self):
return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
def get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
enable_comm_debug_mode: bool,
):
@contextlib.contextmanager
def context():
context_managers = {}
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)
if enable_comm_debug_mode:
comm_mode = stack.enter_context(CommDebugMode())
context_managers["comm_mode"] = comm_mode

yield
yield context_managers

return context

Expand Down Expand Up @@ -214,6 +221,7 @@ def main(job_config: JobConfig):
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
job_config.comm_debug.enable_comm_debug_mode,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
Expand Down Expand Up @@ -381,7 +389,7 @@ def loss_fn(pred, labels):
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
with train_context() as train_contexts:
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
Expand All @@ -390,6 +398,20 @@ def loss_fn(pred, labels):
else:
pp_schedule.step()

if (
job_config.comm_debug.enable_comm_debug_mode
and train_state.step == 1
):
comm_mode = train_contexts["comm_mode"]
comm_mode.log_comm_debug_tracing_table_to_file(
file_name=job_config.comm_debug.dump_file,
noise_level=job_config.comm_debug.noise_level,
)
comm_mode.generate_json_dump(
file_name=job_config.comm_debug.dump_json,
noise_level=job_config.comm_debug.noise_level,
)

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
Expand All @@ -398,14 +420,28 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with train_context():
with train_context() as train_contexts:
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()

if (
job_config.comm_debug.enable_comm_debug_mode
and train_state.step == 1
):
comm_mode = train_contexts["comm_mode"]
comm_mode.log_comm_debug_tracing_table_to_file(
file_name=job_config.comm_debug.dump_file,
noise_level=job_config.comm_debug.noise_level,
)
comm_mode.generate_json_dump(
file_name=job_config.comm_debug.dump_json,
noise_level=job_config.comm_debug.noise_level,
)

# clip gradients
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
Expand Down

0 comments on commit 8c6daf9

Please sign in to comment.