[torchtitan][debug] integrated CommDebugMode into TorchTitan
ghstack-source-id: ca3a9f5983a86dad0b867a8ec92e0e878e7784d5
Pull Request resolved: #480
sinhaanshul committed Jul 24, 2024
1 parent 0f70507 commit 85c27f0
Showing 2 changed files with 58 additions and 5 deletions.
26 changes: 26 additions & 0 deletions torchtitan/config_manager.py
@@ -482,6 +482,32 @@ def __init__(self):
""",
)

# commdebugmode configs
self.parser.add_argument(
"--comm_debug.enable_comm_debug_mode",
default=False,
action="store_true",
help="Whether to enable CommDebugMode, should be used only on first step",
)
self.parser.add_argument(
"--comm_debug.dump_file",
type=str,
default="torchtitan_comm_debug_dump.txt",
help="Which file to dump CommDebugMode's output to",
)
self.parser.add_argument(
"--comm_debug.dump_json",
type=str,
default="torchtitan_comm_debug_log.json",
help="Which file to dump CommDebugMode's json to",
)
self.parser.add_argument(
"--comm_debug.noise_level",
type=int,
default=2,
help="Sets noise level for CommDebugMode's output, controls how much info is displayed",
)

# communications library settings
self.parser.add_argument(
"--comm.init_timeout_seconds",
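
As a quick illustration (not part of this commit), the new options land on the parsed JobConfig under the comm_debug section, following the --section.option convention config_manager.py already uses. A minimal sketch, assuming JobConfig.parse_args accepts an argument list:

    from torchtitan.config_manager import JobConfig

    job_config = JobConfig()
    # Enable CommDebugMode and lower the verbosity; options left unspecified
    # keep their defaults (dump_file, dump_json).
    job_config.parse_args(
        [
            "--comm_debug.enable_comm_debug_mode",
            "--comm_debug.noise_level",
            "1",
        ]
    )
    assert job_config.comm_debug.enable_comm_debug_mode
    print(job_config.comm_debug.dump_file)    # torchtitan_comm_debug_dump.txt
    print(job_config.comm_debug.noise_level)  # 1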
37 changes: 32 additions & 5 deletions train.py
@@ -23,7 +23,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel

from torch.distributed._tensor.debug import CommDebugMode
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_data_loader, create_tokenizer
@@ -138,18 +138,22 @@ def zero_grad(self):
return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool, enable_comm_debug_mode: bool):
@contextlib.contextmanager
def context():
context_managers = {}
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(loss_parallel())
if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)
if enable_comm_debug_mode:
comm_mode = stack.enter_context(CommDebugMode())
context_managers["comm_mode"] = comm_mode

yield
yield context_managers

return context
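
For reference, a minimal sketch of how a caller consumes the reworked context (model and input_ids are assumed to already exist; this is not part of the diff). The yielded dict carries the live CommDebugMode instance only when the flag is enabled:

    train_context = get_train_context(
        enable_loss_parallel=False,
        enable_compiled_autograd=False,
        enable_comm_debug_mode=True,
    )

    with train_context() as tc:
        loss = model(input_ids).sum()  # placeholder forward / backward
        loss.backward()
        if "comm_mode" in tc:  # key exists only when the debug flag is on
            print(tc["comm_mode"].get_total_counts())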

@@ -214,6 +218,7 @@ def main(job_config: JobConfig):
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
job_config.comm_debug.enable_comm_debug_mode,
)

# loss fn can be shared by pipeline-parallel or non-pp execution
@@ -381,7 +386,7 @@ def loss_fn(pred, labels):
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
with train_context() as tc:
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
@@ -390,6 +395,17 @@ def loss_fn(pred, labels):
else:
pp_schedule.step()

if job_config.comm_debug.enable_comm_debug_mode and train_state.step == 1:
comm_mode = tc["comm_mode"]
comm_mode.log_comm_debug_tracing_table_to_file(
file_name=job_config.comm_debug.dump_file,
noise_level=job_config.comm_debug.noise_level
)
comm_mode.generate_json_dump(
file_name=job_config.comm_debug.dump_json,
noise_level=job_config.comm_debug.noise_level
)

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
@@ -398,14 +414,25 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with train_context():
with train_context() as tc:
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()

if job_config.comm_debug.enable_comm_debug_mode and train_state.step == 1:
comm_mode = tc["comm_mode"]
comm_mode.log_comm_debug_tracing_table_to_file(
file_name=job_config.comm_debug.dump_file,
noise_level=job_config.comm_debug.noise_level
)
comm_mode.generate_json_dump(
file_name=job_config.comm_debug.dump_json,
noise_level=job_config.comm_debug.noise_level
)

# clip gradients
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
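
Outside of TorchTitan, the same two dump calls can be exercised against a bare DTensor op. A standalone sketch, assuming torch.distributed is already initialized (e.g. via torchrun) and CUDA devices are available:

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
    from torch.distributed._tensor.debug import CommDebugMode

    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
    dtensor = distribute_tensor(torch.randn(8, 8), mesh, placements=[Shard(0)])

    with CommDebugMode() as comm_mode:
        dtensor.full_tensor()  # gathering the full tensor issues an all-gather

    print(comm_mode.get_comm_counts())  # maps each collective to its count
    comm_mode.log_comm_debug_tracing_table_to_file(
        file_name="torchtitan_comm_debug_dump.txt", noise_level=2
    )
    comm_mode.generate_json_dump(
        file_name="torchtitan_comm_debug_log.json", noise_level=2
    )

Higher noise_level values include progressively more operation-level detail in both dumps.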
