Update on "[WIP] zero bubble"
To run the zero bubble (zb) test:
`python test_runner.py ./test_out --test pp_zb`

TODO:
- zero bubble is failing on multiple hosts when activation checkpointing (AC) is turned off:
```
File "/packages/torchtitan_additional_packages/torchtitan/torchtitan/parallelisms/pipelining/stage.py", line 668, in backward_weight_one_chunk
      dweights = self.dw_runner.pop(bwd_chunk_id)(
    File "/packages/torchtitan_additional_packages/torchtitan/torchtitan/parallelisms/pipelining/_backward.py", line 251, in stage_backward_weight
      dweight = all_dweights[grad_acc]
  KeyError: <AccumulateGrad object at 0x7ff490125b10>
```
  


[ghstack-poisoned]
H-Huang committed Sep 26, 2024
2 parents ff4b1e6 + b346b92 commit d79c77f
Showing 9 changed files with 178 additions and 113 deletions.
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
@@ -499,7 +499,7 @@ def __init__(self):
self.parser.add_argument(
"--comm.init_timeout_seconds",
type=int,
default=600, # 300 seconds is timing out?
default=700, # 300 seconds is timing out?
help="Timeout for communication operations, during initialization and first train step.",
)
self.parser.add_argument(
5 changes: 2 additions & 3 deletions torchtitan/parallelisms/pipelining/_IR.py
@@ -364,7 +364,7 @@ def __init__(self, module, garbage_collect_values=True):
super().__init__(module, garbage_collect_values)
self.value_remap = {}

def run(self, *args, initial_env=None):
def run(self, *args, initial_env=None): # type: ignore[override]
self.value_remap = {}
return super().run(*args, initial_env=initial_env)

@@ -932,8 +932,7 @@ def move_param_to_callee(
if node.op == "get_attr":
# get_attr might get access deeper level attribute
fqn = scope + "." + node.target if scope else node.target
if fqn in unused_attributes: # used, remove it
unused_attributes.remove(fqn)
unused_attributes.discard(fqn)
for _name, _submod in _mod.named_children():
stack.append((scope + "." + _name if scope else _name, _submod))
# delete unused attributes
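For context on the `unused_attributes.discard(fqn)` change above: `set.discard` removes an element if it is present and silently does nothing otherwise, so the membership check that previously guarded `.remove()` is no longer needed. A tiny sketch with made-up attribute names:

```python
unused_attributes = {"layers.0.weight", "layers.0.bias"}

unused_attributes.discard("layers.0.weight")  # present -> removed
unused_attributes.discard("layers.7.scale")   # absent -> no-op; .remove() would raise KeyError

print(unused_attributes)  # {'layers.0.bias'}
```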
43 changes: 35 additions & 8 deletions torchtitan/parallelisms/pipelining/_backward.py
@@ -148,7 +148,6 @@ def stage_backward_input(
"""
compute the gradients for only the stage inputs with respect to the stage outputs
"""

stage_output_grad_fns: List[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs))
)
@@ -225,6 +224,13 @@ def stage_backward_weight(
)
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

# Break a reference cycle caused inside stage_backward_input->get_hook->hook
# The summarized cycle is:
# `hook` -> cell -> param_group -> intermediates -> `hook`
# because we install the hook function onto each of the intermediate autograd nodes.
# We need to keep intermediates alive up until backward_weight, but we can free it now.
del param_group["intermediates"]

assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
# We do not need to retain_graph because... guarantee no overlap?
@@ -269,10 +275,15 @@ def stage_backward(
try:
# stage_output may be a composite datatype like dict. Extract all individual
# tensor values here
stage_output_tensors = []
output_grad_tensors = []

def extract_tensors_with_grads(output_val, grad_val):
stage_output_tensors: List[torch.Tensor] = []
output_grad_tensors: List[Optional[torch.Tensor]] = []

def extract_tensors_with_grads(
output_val,
grad_val,
# Don't delete me - see [Note: ref cycle]
extract_tensors_with_grads,
):
if isinstance(output_val, torch.Tensor):
if not output_val.requires_grad and output_val.grad_fn is None:
return
@@ -289,19 +300,35 @@ def extract_tensors_with_grads(output_val, grad_val):
), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
assert len(output_val) == len(grad_val)
for ov, gv in zip(output_val, grad_val):
extract_tensors_with_grads(ov, gv)
extract_tensors_with_grads(
ov,
gv,
extract_tensors_with_grads,
)
elif isinstance(output_val, dict):
if grad_val is None:
return
assert isinstance(grad_val, dict)
assert set(output_val.keys()) == set(grad_val.keys())
for k in output_val.keys():
extract_tensors_with_grads(output_val[k], grad_val[k])
extract_tensors_with_grads(
output_val[k], grad_val[k], extract_tensors_with_grads
)
else:
# Output is a non-tensor type; just ignore it
pass

extract_tensors_with_grads(stage_output, output_grads)
# Note: ref cycle
# break a ref cycle that would keep tensors alive until GC runs
# 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward
# and used in extract_tensors_with_grads
# 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors,
# and to itself (extract_tensors_with_grads) since it makes a recursive call
# 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad
# fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore
extract_tensors_with_grads(
stage_output, output_grads, extract_tensors_with_grads
)

torch.autograd.backward(
stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type]
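A minimal, self-contained sketch of the closure cycle described in the `[Note: ref cycle]` comment above (all names here are made up): a nested function that both closes over a local list and calls itself recursively leaves an unreachable `function -> closure cell -> function` cycle behind when its defining frame exits, which pins the list until the garbage collector runs. Passing the function to itself as an argument avoids the cycle, so plain refcounting frees everything immediately.

```python
import gc
import weakref


class Payload:
    """Stands in for an activation tensor we want freed promptly."""


def leaky():
    collected = []  # stands in for stage_output_tensors
    payload = Payload()
    ref = weakref.ref(payload)

    def walk(val):
        # Closes over `collected` *and* over itself (for the recursive call),
        # so walk -> closure cell -> walk is a cycle that also pins `collected`.
        if isinstance(val, Payload):
            collected.append(val)
        elif isinstance(val, (list, tuple)):
            for v in val:
                walk(v)

    walk([payload])
    return ref  # frame exits; the walk <-> cell cycle keeps `collected` alive


def fixed():
    collected = []
    payload = Payload()
    ref = weakref.ref(payload)

    def walk(val, walk_fn):
        # The recursive reference is a parameter, not a closed-over cell,
        # so no cycle is formed and refcounting alone cleans everything up.
        if isinstance(val, Payload):
            collected.append(val)
        elif isinstance(val, (list, tuple)):
            for v in val:
                walk_fn(v, walk_fn)

    walk([payload], walk)
    return ref


gc.disable()  # rely on refcounting only, as on the hot path between gc runs
print(leaky()() is None)  # False: the cycle keeps the payload alive until gc.collect()
print(fixed()() is None)  # True: freed as soon as the frame's locals are dropped
gc.enable()
```

The `del param_group["intermediates"]` added to `stage_backward_weight` addresses the analogous `hook -> cell -> param_group -> intermediates -> hook` cycle that the diff's other comment describes.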
52 changes: 35 additions & 17 deletions torchtitan/parallelisms/pipelining/schedules.py
@@ -580,12 +580,7 @@ def __init__(
self._num_stages = stage.num_stages
# Set the same has_backward flag for stage object
self._stage.has_backward = self._has_backward

# TODO: later replace this with lazy shape inference during forward
# Prepare forward send/recv infrastructure for stage
stage._prepare_forward_infra(n_microbatches)
if self._has_backward:
stage._prepare_backward_infra(n_microbatches)
self._stage_initialized = False

def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
"""
@@ -643,6 +638,9 @@ def _step_microbatches(
)

arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stage_initialized:
self._stage._prepare_forward_infra(self._n_microbatches)
self._stage_initialized = True

# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []
@@ -692,6 +690,12 @@ def _step_microbatches(
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

if not self._stage_initialized:
self._stage._prepare_forward_infra(self._n_microbatches)
if self._has_backward:
self._stage._prepare_backward_infra(self._n_microbatches)
self._stage_initialized = True

# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []

@@ -772,6 +776,12 @@ def _step_microbatches(
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

if not self._stage_initialized:
self._stage._prepare_forward_infra(self._n_microbatches)
if self._has_backward:
self._stage._prepare_backward_infra(self._n_microbatches)
self._stage_initialized = True

# Last stage has 1 warmup, second-to-last 2 warmups, ...
# first stage `num_stages` warmups
warmup_chunks = min(
@@ -1076,22 +1086,16 @@ def __init__(
# Set the same has_backward flag for stage object
for stage in self._stages:
stage.has_backward = self._has_backward
self._stages_initialized = False

self._should_compute_loss = (
lambda stage: stage.is_last and self._loss_fn is not None
)
# avoid putting a reference to 'self' inside the lambda; it creates a ref cycle
has_loss: bool = self._loss_fn is not None
self._should_compute_loss = lambda stage: stage.is_last and has_loss

# This will be set during init of derived schedules
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.use_full_backward = use_full_backward

# TODO: later replace this with lazy shape inference during forward
# Prepare forward send/recv infrastructure for stage
for stage in self._stages:
stage._prepare_forward_infra(n_microbatches)
if self._has_backward:
stage._prepare_backward_infra(n_microbatches)

def _dump_csv(self, filename):
"""Dump a CSV representation of the schedule into a file with the provided filename."""
with open(filename, "w", newline="") as csvfile:
@@ -1186,7 +1190,6 @@ def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
target: target for the loss function.
losses: a list to store the losses for each microbatch.
"""

# Clean per iteration
for stage in self._stages:
stage.clear_runtime_states()
@@ -1225,6 +1228,14 @@ def _step_microbatches(
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

if not self._stages_initialized:
for stage in self._stages:
# TODO: why do I pass args/kwargs here? It's not used?
stage._prepare_forward_infra(self._n_microbatches)
if self._has_backward:
stage._prepare_backward_infra(self._n_microbatches)
self._stages_initialized = True

# Based on the plan in Step 1 created in __init__:
# 2. Perform communication based on the pipeline_order
stage_index_to_stage: Dict[int, _PipelineStageBase] = {
@@ -1451,6 +1462,13 @@ def _step_microbatches(
not support models with skip connections.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stages_initialized:
for stage in self._stages:
# TODO: why do I pass args/kwargs here? It's not used?
stage._prepare_forward_infra(self._n_microbatches)
if self._has_backward:
stage._prepare_backward_infra(self._n_microbatches)
self._stages_initialized = True

# Based on the plan in Step 1 created in __init__:
# 2. Perform communication based on the pipeline_order
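The schedule changes above stop calling `_prepare_forward_infra` / `_prepare_backward_infra` eagerly in `__init__` and instead build the send/recv infrastructure on the first `_step_microbatches` call, guarded by an initialization flag. A minimal sketch of that lazy-initialization pattern with stand-in classes (not the real torchtitan API):

```python
from typing import List


class FakeStage:
    """Stand-in for a pipeline stage; records whether its infra was built."""

    def __init__(self) -> None:
        self.fwd_ready = False
        self.bwd_ready = False

    def _prepare_forward_infra(self, n_microbatches: int) -> None:
        self.fwd_ready = True

    def _prepare_backward_infra(self, n_microbatches: int) -> None:
        self.bwd_ready = True


class LazySchedule:
    def __init__(self, stages: List[FakeStage], n_microbatches: int, has_backward: bool) -> None:
        self._stages = stages
        self._n_microbatches = n_microbatches
        self._has_backward = has_backward
        self._stages_initialized = False  # infra is built on the first step, not in __init__

    def _maybe_initialize(self) -> None:
        if self._stages_initialized:
            return
        for stage in self._stages:
            stage._prepare_forward_infra(self._n_microbatches)
            if self._has_backward:
                stage._prepare_backward_infra(self._n_microbatches)
        self._stages_initialized = True

    def _step_microbatches(self) -> None:
        self._maybe_initialize()  # the first call pays the one-time setup cost
        # ... run the forward/backward schedule here ...


sched = LazySchedule([FakeStage(), FakeStage()], n_microbatches=4, has_backward=True)
sched._step_microbatches()
print(all(s.fwd_ready and s.bwd_ready for s in sched._stages))  # True
```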
88 changes: 52 additions & 36 deletions torchtitan/parallelisms/pipelining/stage.py
@@ -13,6 +13,7 @@
from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
from torch.fx.node import map_aggregate
from torch.nn.parallel import DistributedDataParallel
from torch.utils._pytree import tree_map_only

from ._backward import stage_backward, stage_backward_input, stage_backward_weight
from ._debug import map_debug_info
@@ -251,7 +252,10 @@ def map_recv_to_send(a):
return grad_send_info

@abstractmethod
def _prepare_forward_infra(self, num_microbatches: int):
def _prepare_forward_infra(
self,
num_microbatches: int,
):
raise NotImplementedError

def _prepare_backward_infra(self, num_microbatches: int):
@@ -680,14 +684,6 @@ def backward_one_chunk(
if isinstance(bwd_kwargs["stage_output"], torch.Tensor):
bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],)

# if self.stage_index == 0:
# for inp in bwd_kwargs["input_values"]:
# if not inp.requires_grad:
# inp.requires_grad_(True)

# for inp in bwd_kwargs["input_values"]:
# print(inp.requires_grad)

grads_input, param_groups = self.backward_maybe_with_nosync(
"input", bwd_kwargs
)
@@ -853,7 +849,10 @@ def _move_submod_to_device(self):
else:
self.submod.to(self.device)

def _prepare_forward_infra(self, num_microbatches: int):
def _prepare_forward_infra(
self,
num_microbatches: int,
):
"""
Create send/recv infrastructures for activations (during forward)
"""
@@ -1266,24 +1265,27 @@ def __init__(
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
):
super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
self.submod.to(self.device)
# When we materialize the model partition on cuda, we call reset_parameters() if it is available
self.inputs: List[torch.Tensor] = []
self.outputs: List[torch.Tensor] = []

self.inputs = _create_empty_tensors(input_args, device)
self.inputs: Optional[List[torch.Tensor]] = None

# Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it
# might be breaking for existing users.
self.inputs_meta = (
(input_args,) if isinstance(input_args, torch.Tensor) else input_args
)
if output_args is None:
logger.info("output_args not provided, performing forward using input_args")
self.outputs = self.submod(*self.inputs)
# create buffers for the output so that the data is in the correct
# shape in order to use in p2p op (send)
self.outputs = _create_empty_tensors(self.outputs, device)
else:
self.outputs = _create_empty_tensors(output_args, device)

self._configure_outputs_meta(tuple(self.outputs))

try:
output_args = submodule(*self.inputs_meta)
output_args = tree_map_only(
torch.Tensor, lambda x: x.to("meta"), output_args
)
except Exception as e:
raise RuntimeError(
"Failed to perform pipeline shape inference- are your inputs on the same device as your module?"
) from e
assert output_args is not None # for mypy
self._configure_outputs_meta(
(output_args,) if isinstance(output_args, torch.Tensor) else output_args
)
# these are the buffers used in backwards send/recv, they are allocated later
self.outputs_grad: List[torch.Tensor] = []

@@ -1300,11 +1302,16 @@ def stage_global_rank(peer_rank):
logger.debug(
f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004
f"{self.is_last=}, {self.num_stages=}, "
f"inputs: {[inp.shape for inp in self.inputs]}, "
f"output: {[output.shape for output in self.outputs]}"
f"inputs: {[inp.shape for inp in self.inputs_meta]}, "
f"output: {[output.shape for output in self.get_outputs_meta()]}"
)

def _prepare_forward_infra(self, num_microbatches: int) -> None:
def _prepare_forward_infra(
self,
num_microbatches: int,
) -> None:
# TODO move self.device to an argument from step API (from its input tensors)?

# Receive info during forward
# TODO: create args_recv_info lazily? (same needed for PipelineStage)
for chunk_id in range(num_microbatches):
@@ -1318,20 +1325,23 @@ def _prepare_forward_infra(self, num_microbatches: int) -> None:
self.stage_index - 1,
_make_tensor_from_meta(inp, self.device),
)
for inp in self.inputs
for inp in self.inputs_meta
]
)

self.args_recv_info[chunk_id] = recv_infos
else:
self.args_recv_info[chunk_id] = tuple(
[_RootArgPlaceholder(i) for i in self.inputs]
[_RootArgPlaceholder(i) for i in self.inputs_meta]
)

# Send info during forward for each activation
# only need the rank that is being sent to
self.act_send_info: Dict[int, List] = {}
for idx in range(len(self.outputs)):

# TODO: we didn't require output args at __init__ before, but now we do; enforce it until we enable lazy init
# get_outputs_meta will assert for us
for idx in range(len(self.get_outputs_meta())):
# We assume we always send to stage + 1
if not self.is_last:
self.act_send_info[idx] = [self.stage_index + 1]
@@ -1351,7 +1361,9 @@ def _create_grad_recv_info(
_RecvInfo(
f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
dst_list[0],
_make_tensor_from_meta(self.outputs[idx], self.device),
_make_tensor_from_meta(
self.get_outputs_meta()[idx], self.device
),
)
for idx, dst_list in act_send_info.items()
]
@@ -1442,9 +1454,13 @@ def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
_create_metadata_tensor(device=stage.device)
for _ in range(stage.group_size)
]
expected_outputs = stage.outputs
stage_output = _create_metadata_tensor(expected_outputs, device=stage.device)
dist.all_gather(tensor_list, stage_output)
outputs_meta = stage.get_outputs_meta()
# TODO, (1) are we deleting output validation when we move to shape inference?
# (2) if not, we should support multiple outputs
assert (
len(outputs_meta) == 1
), f"validation logic assumes single output, got {len(outputs_meta)} outputs "
dist.all_gather(tensor_list, outputs_meta[0])
stage_output_shapes = [
_extract_metadata_from_tensor(tensor) for tensor in tensor_list
]
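The rewritten `PipelineStage.__init__` above derives output metadata by running the submodule on meta-device inputs, so no real buffers are allocated just to learn output shapes. A rough sketch of the same idea, with an invented toy module and shapes (the real class also normalizes single tensors into tuples and wraps failures in a descriptive `RuntimeError`, as shown in the diff):

```python
import torch
from torch import nn
from torch.utils._pytree import tree_map_only

# A toy stage submodule; any nn.Module works the same way.
submod = nn.Linear(16, 32, device="meta")

# Meta-device inputs carry only shape/dtype, so this "forward" allocates no real memory.
inputs_meta = (torch.empty(8, 16, device="meta"),)

try:
    output_args = submod(*inputs_meta)
    # Normalize every tensor in the (possibly nested) output onto the meta device.
    output_args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), output_args)
except Exception as e:
    raise RuntimeError(
        "Shape inference failed - are the inputs on the same device as the module?"
    ) from e

outputs_meta = (output_args,) if isinstance(output_args, torch.Tensor) else output_args
print([o.shape for o in outputs_meta])  # [torch.Size([8, 32])]
```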
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -54,7 +54,7 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
mode = 'none' # ['none', 'selective', 'full']
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]