Commit

Update

[ghstack-poisoned]
wconstab committed May 20, 2024
2 parents abe0a2b + 3f8f1a3 commit 7f2941b
Showing 3 changed files with 19 additions and 201 deletions.
12 changes: 12 additions & 0 deletions torchtitan/config_manager.py
@@ -269,6 +269,18 @@ def __init__(self):
split via the provided split points, unflattened into an nn.Module,
and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_microbatches",
type=int,
default=None,
help="""
How many microbatches to split the full training batch into when using pipeline parallelism.
The overall training batch size must be evenly divisible by the number of microbatches.
The default value will be the number of pipeline stages, if unspecified.
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=torch_dtype,
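A minimal sketch (not part of this commit; the helper name and checks are assumed) of how the new --experimental.pipeline_parallel_microbatches option is described to behave: it defaults to the pipeline-parallel degree, and the global training batch size must divide evenly by it.

from typing import Optional

def resolve_num_microbatches(batch_size: int, pp_degree: int, configured: Optional[int]) -> int:
    # Fall back to the number of pipeline stages when the option is left unset.
    n_microbatches = configured if configured is not None else pp_degree
    # The help text requires the training batch size to be evenly divisible.
    if batch_size % n_microbatches != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by {n_microbatches} microbatches"
        )
    return n_microbatches

# Example: batch size 8, PP degree 4, option left at its default of None -> 4 microbatches.
assert resolve_num_microbatches(8, 4, None) == 4
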
9 changes: 5 additions & 4 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -184,9 +184,10 @@ def pipeline_llama_manual(
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
# heuristically == PP dim but should be a config
microbatches = parallel_dims.pp
stage_idx = pp_rank # TODO support virtual stages
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
stage_idx = pp_rank
this_stage_layer_names = split_stage_fqns(
_llama_fqns(len(model.layers)),
job_config.experimental.pipeline_parallel_split_points,
@@ -303,7 +304,7 @@ def pipeline_llama_tracer(
# Create a pipeline representation from the model
pipe = pipeline(
model,
parallel_dims.pp,
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
example_args=_llama_trace_input(job_config, model_config),
split_spec=split_spec,
)
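For illustration only (a toy snippet, not from the repo): the microbatch count chosen above determines how the schedule splits the global training batch along dim 0, which is why the divisibility requirement exists.

import torch

global_batch = torch.randn(8, 16)   # global batch size 8
n_microbatches = 4                  # e.g. the PP degree when the option is unset
chunks = torch.chunk(global_batch, n_microbatches, dim=0)
assert len(chunks) == 4 and chunks[0].shape == (2, 16)
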
199 changes: 2 additions & 197 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -3,204 +3,9 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe
from collections import defaultdict

from typing import Dict, List, Optional

import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe

# imports related to local copy of Schedule1F1B with local fix
from torch.distributed.pipelining.PipelineSchedule import (
PipelineScheduleSingle,
# sorted_batch_p2p,
)
from torch.profiler import record_function

from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe
from torchtitan.logging_utils import logger

# haven't landed these yet in core
def batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
desc_str = f"{desc}, " if desc else ""
logger.debug(f"batch_p2p {desc_str}{p2p_ops}") # noqa: G004
return dist.batch_isend_irecv(p2p_ops).pop()


def sorted_batch_p2p(
p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
"""
Sorts the list of P2P ops by the peer rank, and then calls
batch_isend_irecv. Return a dictionary of works by peer rank. This function
helps us avoid hangs in case of skip connections.
"""
# Arrange p2p_ops by peer rank:
# int is the peer rank;
# List is the list of ops towards the peer
ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
work_by_peer: Dict[int, dist.Work] = {}
if len(p2p_ops) == 0:
return work_by_peer

# Classify the ops by peer rank
for op in p2p_ops:
ops_by_peer[op.peer].append(op)

# Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
for peer, ops in sorted(ops_by_peer.items()):
work_by_peer[peer] = batch_p2p(ops, desc=desc)

return work_by_peer
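
A toy illustration (plain Python, no process group needed) of the bucketing that sorted_batch_p2p's docstring describes: ops are grouped by peer rank and then issued in sorted peer order, so every rank batches its communication with a given peer consistently.

from collections import defaultdict

fake_ops = [("send", 3), ("recv", 1), ("send", 1), ("recv", 3)]  # (kind, peer rank)
ops_by_peer = defaultdict(list)
for kind, peer in fake_ops:
    ops_by_peer[peer].append(kind)
for peer, kinds in sorted(ops_by_peer.items()):
    print(peer, kinds)  # 1 ['recv', 'send'] then 3 ['send', 'recv']
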


class Schedule1F1B(PipelineScheduleSingle):
"""
The 1F1B schedule.
Will perform one forward and one backward on the microbatches in steady state.
"""

def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
Will go through all the microbatches according to the 1F1B schedule.
Args:
microbatches: list of microbatch args.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

# forward for num_microbatches + backward for num_microbatches
total_ops = self._n_microbatches * 2

# Example, 4 GPUs, 8 microbatches
# Stage 0: 6 warmup, 2 1f1b, 6 cooldown
# Stage 1: 4 warmup, 4 1f1b, 4 cooldown
# Stage 2: 2 warmup, 6 1f1b, 2 cooldown
# Stage 3: 0 warmup, 8 1f1b, 0 cooldown
# fwd only
warmup_steps = min(
self._n_microbatches,
2 * (self._num_stages - self._stage.stage_index - 1),
)
# fwd + bwd
main_1f1b_steps = self._n_microbatches - warmup_steps
# bwd only
cooldown_steps = total_ops - (warmup_steps + (2 * main_1f1b_steps))
total_steps = warmup_steps + main_1f1b_steps + cooldown_steps
logger.debug(
f"Stage {self._stage.stage_index}: " # noqa: G004
f"Warmup steps: {warmup_steps}, "
f"Main 1F1B steps: {main_1f1b_steps}, "
f"Cooldown steps: {cooldown_steps}, "
f"Total steps: {total_steps}"
)

# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []
bwd_sends_to_wait: List[dist.Work] = []

def is_forward_step(i):
assert i >= 0, i
return i < self._n_microbatches

def is_backward_step(i):
assert i < total_steps, i
return i >= warmup_steps and self._has_backward

def is_1f1b_step(i):
return is_forward_step(i) and is_backward_step(i)

def is_warmup_step(i):
return is_forward_step(i) and not is_backward_step(i)

def is_cooldown_step(i):
return not is_forward_step(i) and is_backward_step(i)

def should_coalesce_fwd_send_bwd_recv(fwd_send_i):
return (
is_1f1b_step(fwd_send_i)
or (is_warmup_step(fwd_send_i) and is_cooldown_step(fwd_send_i + 1))
or (
fwd_send_i >= 1
and is_warmup_step(fwd_send_i - 1)
and is_cooldown_step(fwd_send_i)
)
)

def should_coalesce_bwd_send_fwd_recv(bwd_send_i):
# The backward send to prev stage should be coalesced with the fwd recv from the previous stage
return bwd_send_i >= warmup_steps and is_1f1b_step(bwd_send_i + 1)

# bwd chunk counter
bwd_mb_index = 0
self._stage._configure_data_parallel_mode(last_backward=False)
for i in range(total_steps):
if is_forward_step(i):
with record_function(f"Forward {i}"):
ops = self._stage.get_fwd_recv_ops()
desc = "fwd_recv"
if should_coalesce_bwd_send_fwd_recv(i - 1):
desc += "_bwd_send"
ops.extend(self._stage.get_bwd_send_ops())

works = sorted_batch_p2p(ops, desc=desc)
for work in works.values():
work.wait()

output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]

if not should_coalesce_fwd_send_bwd_recv(i):
ops = self._stage.get_fwd_send_ops()
works = sorted_batch_p2p(ops, desc="fwd_send")
fwd_sends_to_wait.extend(works.values())

self._maybe_compute_loss(self._stage, output, target_mbs, i)

if is_backward_step(i):
self._stage._configure_data_parallel_mode(
last_backward=(i == total_steps - 1)
)
with record_function(f"Backward {bwd_mb_index}"):
ops = self._stage.get_bwd_recv_ops()
desc = "bwd_recv"
if should_coalesce_fwd_send_bwd_recv(i):
ops.extend(self._stage.get_fwd_send_ops())
desc += "_fwd_send"

works = sorted_batch_p2p(ops, desc=desc)
for work in works.values():
work.wait()

loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(loss=loss)

if not should_coalesce_bwd_send_fwd_recv(i):
# see Note: coalesced bwd-send/fwd-recv
ops = self._stage.get_bwd_send_ops()
works = sorted_batch_p2p(ops, desc="bwd_send")
bwd_sends_to_wait.extend(works.values())

bwd_mb_index += 1

# Wait for all forward sends to finish
for work in fwd_sends_to_wait:
work.wait()

# Wait for all backward sends to finish
for work in bwd_sends_to_wait:
work.wait()

# Return losses if there is a container passed in
self._update_losses(self._stage, losses)
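
A standalone sketch (not part of the repo) that reproduces the warmup/steady/cooldown arithmetic from the 1F1B comments above, for the 4-GPU, 8-microbatch example.

def one_f_one_b_counts(n_microbatches: int, num_stages: int, stage_index: int):
    # Later stages need fewer warmup forwards before steady-state 1F1B begins.
    warmup = min(n_microbatches, 2 * (num_stages - stage_index - 1))
    steady = n_microbatches - warmup            # one forward + one backward per step
    cooldown = 2 * n_microbatches - (warmup + 2 * steady)
    return warmup, steady, cooldown

for stage in range(4):
    print(stage, one_f_one_b_counts(8, 4, stage))
# -> 0 (6, 2, 6), 1 (4, 4, 4), 2 (2, 6, 2), 3 (0, 8, 0)
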


def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
if job_config.experimental.pipeline_parallel_schedule == "1f1b":
@@ -216,7 +21,7 @@ def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
)
return schedule_class(
stage,
n_microbatches=parallel_dims.pp,
n_microbatches=stage.chunks,
loss_fn=loss_fn,
)
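
A hypothetical sketch (names assumed; only fragments of build_pipeline_schedule are visible in this hunk) of the dispatch pattern it uses: pick a schedule class from the config string and pass the stage's own chunk count as n_microbatches.

from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe

_SCHEDULES = {"1f1b": Schedule1F1B, "gpipe": ScheduleGPipe}

def make_schedule(name, stage, loss_fn):
    if name not in _SCHEDULES:
        raise NotImplementedError(f"{name} is not implemented")
    return _SCHEDULES[name](stage, n_microbatches=stage.chunks, loss_fn=loss_fn)
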
