
Commit bc3ec02
Refactored activation checkpointing
ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: #447
awgu committed Jul 10, 2024
1 parent c7a6a3e commit bc3ec02
Showing 2 changed files with 35 additions and 52 deletions.
torchtitan/config_manager.py (7 changes: 4 additions & 3 deletions)
@@ -532,10 +532,11 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
             args_dict[first_level_key][second_level_key] = v
         return args_dict
 
-    def _validate_config(self) -> bool:
+    def _validate_config(self) -> None:
         # TODO: Add more mandatory validations
-        assert self.model.name and self.model.flavor and self.model.tokenizer_path
-        return True
+        assert self.model.name
+        assert self.model.flavor
+        assert self.model.tokenizer_path
 
     def parse_args_from_command_line(
         self, args_list
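The two changes above go together: `_validate_config` now signals failure only by raising, so the `return True` (and the `bool` return type) was dead weight, and splitting the compound assert into one assert per field makes the `AssertionError` traceback name the exact setting that is missing. A minimal sketch of the difference, using a hypothetical `ModelConfig` stand-in rather than torchtitan's real config object:

from dataclasses import dataclass


# Hypothetical stand-in for the parsed `self.model` section of JobConfig.
@dataclass
class ModelConfig:
    name: str = "llama3"
    flavor: str = ""  # left empty on purpose to trigger validation
    tokenizer_path: str = "./tokenizer.model"


model = ModelConfig()

# Old style: the traceback shows only the whole expression, so it does not
# say which of the three fields was empty:
#     assert model.name and model.flavor and model.tokenizer_path
# New style: the failing line itself identifies the bad field.
assert model.name
assert model.flavor  # AssertionError points here, i.e. at `flavor`
assert model.tokenizer_path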
torchtitan/parallelisms/parallelize_llama.py (80 changes: 31 additions & 49 deletions)
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# this file applies the PTD parallelisms and various training techniques to the
-# llama model, i.e. activation checkpointing, etc.
+# This file applies the PT-D parallelisms and various training techniques (e.g.
+# activation checkpointing and compile) to the Llama model.
 
 import copy
 from collections import defaultdict
Expand All @@ -17,7 +17,6 @@
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.tensor.parallel import (
Expand All @@ -28,8 +27,6 @@
SequenceParallel,
)

from torch.utils.checkpoint import checkpoint

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
@@ -43,10 +40,25 @@
 }
 
 
-# Uses PTD FSDP AC wrapper
-# currently selective per op and per layer checkpointing are supported
-def checkpoint_wrapper(module, config):
-    if config.mode == "selective" and config.selective_ac_option == "op":
+def checkpoint_wrapper(module: torch.nn.Module, ac_config):
+    valid_ac_modes = ("full", "selective")
+    if ac_config.mode not in valid_ac_modes:
+        raise ValueError(
+            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
+        )
+
+    if ac_config.mode == "full":
+        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+
+    assert ac_config.mode == "selective", f"{ac_config.mode}"
+    use_op_sac = ac_config.selective_ac_option == "op"
+    use_layer_sac = ac_config.selective_ac_option.isdigit()
+    if not use_op_sac and not use_layer_sac:
+        raise ValueError(
+            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
+            f"Valid options: 'op' or a positive int representing layer frequency"
+        )
+    if use_op_sac:
         from torch.utils.checkpoint import (
             CheckpointPolicy,
             create_selective_checkpoint_contexts,
@@ -76,53 +88,23 @@ def selective_checkpointing_context_fn():
 
         return ptd_checkpoint_wrapper(
             module,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-            checkpoint_fn=checkpoint,
             context_fn=selective_checkpointing_context_fn,
-            use_reentrant=False,
             preserve_rng_state=False,
         )
-    elif config.mode == "full":
-        return ptd_checkpoint_wrapper(
-            module,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-            checkpoint_fn=checkpoint,
-            use_reentrant=False,
-            preserve_rng_state=False,
-        )
-
-    elif config.mode == "selective" and config.selective_ac_option.isdigit():
-        """enables selective checkpointing of candidate layers.
-        Usage:
-        'selective_ac_option' with a positive 'int' value in config controls which layers to checkpoint.
-        1 == checkpointing every one (all).
-        2 == checkpoint every 2nd one
-        """
-        ac_freq = int(config.selective_ac_option)
-        assert (
-            ac_freq >= 0
-        ), f"selective layer AC policy (ac_freq) expects a positive integer, received {ac_freq}"
-
-        checkpoint_wrapper.__dict__.setdefault("_count", 0)
-
-        checkpoint_wrapper._count += 1
-        if not ac_freq or checkpoint_wrapper._count % ac_freq == 0:
-            return ptd_checkpoint_wrapper(
-                module,
-                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-                checkpoint_fn=checkpoint,
-                use_reentrant=False,
-                preserve_rng_state=False,
+    elif use_layer_sac:
+        # Checkpoint every `ac_freq` of the modules passed to this function
+        ac_freq = int(ac_config.selective_ac_option)
+        if ac_freq <= 0:
+            raise ValueError(
+                f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}"
             )
-        # skip activation checkpointing and store activations for this layer
+        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
+        ptd_checkpoint_wrapper._count += 1
+        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
+            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
         else:
             return module
 
-    else:
-        raise NotImplementedError(
-            "Unknown AC type or AC config. Only selective op and selective layer ac implemented currently."
-        )
-
 
 def get_tp_parallel_strategy(
     job_config: JobConfig,
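For readers unfamiliar with the per-op path kept above: `create_selective_checkpoint_contexts` consumes a policy callable that is consulted for every ATen op run under the checkpoint, and the contexts it returns tell the non-reentrant checkpoint which op outputs to save and which to recompute in backward. A minimal, self-contained sketch of that mechanism (the policy below saves only plain matmuls and is illustrative, not torchtitan's actual save list, which also covers SDPA and alternates matmuls):

from functools import partial

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def _policy(ctx, op, *args, **kwargs):
    # Keep matmul outputs in memory; everything else is recomputed in backward.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def fn(x, w1, w2):
    return torch.mm(torch.relu(torch.mm(x, w1)), w2)


x, w1, w2 = (torch.randn(8, 8, requires_grad=True) for _ in range(3))
out = checkpoint(
    fn,
    x,
    w1,
    w2,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, _policy),
)
out.sum().backward()  # relu is recomputed; the two mm results were saved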

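The layer-frequency branch relies on mutable state stored on the `ptd_checkpoint_wrapper` function object: each call to `checkpoint_wrapper` bumps `_count`, so wrapping the transformer blocks in order checkpoints every `ac_freq`-th one. A toy sketch of just that counting pattern; the function name and printed plan are hypothetical, only the `_count % ac_freq` test mirrors the diff:

def should_checkpoint(ac_freq: int) -> bool:
    # State lives on the function object, mirroring how the refactor keeps
    # `_count` on `ptd_checkpoint_wrapper` across calls.
    should_checkpoint.__dict__.setdefault("_count", 0)
    should_checkpoint._count += 1
    return should_checkpoint._count % ac_freq == 0


# With ac_freq=2, every second transformer block is checkpointed:
for i in range(6):
    plan = "checkpoint" if should_checkpoint(ac_freq=2) else "keep activations"
    print(f"layers.{i}: {plan}")
# layers.0: keep activations
# layers.1: checkpoint
# ... alternating thereafter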