[Do not review] Activation offloading
ghstack-source-id: 1f53901f927b56c0ff58b81f853e6969cf348b84
Pull Request resolved: #467
awgu committed Jul 18, 2024
1 parent 69fe8de commit 00fa1c1
Showing 5 changed files with 224 additions and 2 deletions.
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -241,6 +241,12 @@ def __init__(self):
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--experimental.offload_activations",
default=False,
action="store_true",
help="Whether to offload activations",
)
self.parser.add_argument(
"--experimental.enable_async_tensor_parallel",
default=False,
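
For reference, a minimal sketch (not part of the diff) of how the new switch surfaces through JobConfig: it is a plain store_true flag that defaults to False, so appending --experimental.offload_activations to the training command enables it. The explicit argument list below is only for illustration and assumes JobConfig.parse_args accepts one, as train.py's usage suggests.

from torchtitan.config_manager import JobConfig

# Hedged sketch: parse the new flag and read it back from the experimental
# section, the same way train.py checks it before applying offloading.
config = JobConfig()
config.parse_args(["--experimental.offload_activations"])
assert config.experimental.offload_activations is True
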
146 changes: 146 additions & 0 deletions torchtitan/parallelisms/offload_utils.py
@@ -0,0 +1,146 @@
from typing import Dict, Optional, Tuple

import torch

from torch.autograd.graph import saved_tensors_hooks


HandleKey = Tuple[torch.device, torch.Tensor]


class Handle:
def __init__(
self,
device_tensor: torch.Tensor,
offload_stream: torch.cuda.Stream,
):
if not torch.is_tensor(device_tensor):
raise ValueError(f"Expects tensor but got {device_tensor}")
self.device_tensor: Optional[torch.Tensor] = device_tensor
self.cpu_tensor: Optional[torch.Tensor] = None
self.offload_stream = offload_stream
self.d2h_event: Optional[torch.cuda.Event] = None
self.h2d_event: Optional[torch.cuda.Event] = None
self.device: torch.device = device_tensor.device

def copy_d2h_async(self) -> None:
current_stream = torch.cuda.current_stream()
self.offload_stream.wait_stream(current_stream)
with torch.cuda.stream(self.offload_stream):
self.cpu_tensor = self.device_tensor.to(
torch.device("cpu"), non_blocking=True
)
self.d2h_event = self.offload_stream.record_event()

def copy_h2d_async(self) -> None:
if self.device_tensor is not None:
return
assert self.cpu_tensor is not None
self.device_tensor = torch.empty_like(self.cpu_tensor, device=self.device)
self.offload_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.offload_stream):
self.device_tensor.copy_(self.cpu_tensor, non_blocking=True)
self.h2d_event = self.offload_stream.record_event()

def wait_for_d2h(self):
if self.d2h_event:
torch.cuda.current_stream().wait_event(self.d2h_event)
self.device_tensor = None

def wait_for_h2d(self):
if self.h2d_event:
torch.cuda.current_stream().wait_event(self.h2d_event)
self.cpu_tensor = None


class offload_to_cpu(saved_tensors_hooks):
"""
This represents a saved-tensors-hooks context that offloads activations to
CPU in forward and brings them back to the device in backward.

In forward, the D2H copy is always async. Device memory is freed when the
user calls :meth:`wait_for_d2h`, which should be done after the compute
with which to overlap has been issued.

In backward, the H2D copy defaults to sync. However, the user may call
:meth:`copy_h2d_async` to issue the H2D copy asynchronously; this should be
done before issuing the compute with which to overlap. When the activation
is used in backward, we wait for that H2D copy without any further user
intervention.

The D2H and H2D copies always use pinned memory, so the user should ensure
that there is sufficient CPU RAM available to be pinned; otherwise, the
program can slow down or freeze. The first few iterations will be much
slower due to repeated ``cudaHostAlloc`` calls that warm up the CPU caching
allocator.
"""

def __init__(self, offload_stream: torch.cuda.Stream):
self.handle_key_to_handle: Dict[HandleKey, Handle] = {}
self.offload_stream = offload_stream

def pack_to_cpu(tensor: torch.Tensor):
if tensor.device.type == "cpu":
return (tensor.device, tensor)

device_tensor = tensor
del tensor
# TODO: We need a way to decide whether or not to offload a given tensor;
# the decision might need to be a function of the op constructing the
# tensor, the pipeline-parallel rank, etc.
if device_tensor.numel() < (14336 * 8192): # (FFN dim * seq_len) for 8B
return (device_tensor.device, device_tensor)

handle = Handle(device_tensor, offload_stream)
handle.copy_d2h_async()

assert handle.cpu_tensor is not None
handle_key = (device_tensor.device, handle.cpu_tensor)
self.handle_key_to_handle[handle_key] = handle

return handle_key

def unpack_from_cpu(handle_key: HandleKey):
device, tensor = handle_key
if tensor.device == device:
return tensor

assert tensor.device == torch.device("cpu"), f"{tensor.device}"
cpu_tensor = tensor
del tensor

handle = self.handle_key_to_handle.get(handle_key, None)
if handle is None:
raise RuntimeError(f"Handle missing for {handle_key}")

handle.wait_for_h2d()
if handle.device_tensor is not None:
device_tensor = handle.device_tensor
handle.device_tensor = None
return device_tensor

# Fallback to non-overlapped H2D copy
device_tensor = cpu_tensor.to(device, non_blocking=True)
assert handle.cpu_tensor is None
return device_tensor

super().__init__(pack_to_cpu, unpack_from_cpu)

def wait_for_d2h(self):
for handle in self.handle_key_to_handle.values():
handle.wait_for_d2h()

def copy_h2d_async(self):
# HACK: Sleeping for 1 ms before issuing the H2D copies helps avoid a
# no-overlap issue with `reshard_after_forward=True`, where the all-gather
# copy-out's H2D copy serializes after these H2D copies, preventing overlap.
# self.offload_stream.wait_stream(torch.cuda.current_stream())
# with torch.cuda.stream(self.offload_stream):
# from torch.testing._internal.common_utils import get_cycles_per_ms
# torch.cuda._sleep(int(get_cycles_per_ms()))
for handle in self.handle_key_to_handle.values():
handle.copy_h2d_async()

def __enter__(self):
super().__enter__()
# Override this to return `self` so that the context can be saved like
# with `offload_to_cpu(offload_stream) as ctx:`
return self
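
To make the intended call pattern concrete, here is a small usage sketch of offload_to_cpu outside of torchtitan. It is mine, not part of the diff; the two elementwise "layers", the tensor shape, and the variable names are placeholders, with the shape chosen so the saved activations reach the hard-coded 14336 * 8192 element threshold above.

import torch

from torchtitan.parallelisms.offload_utils import offload_to_cpu

offload_stream = torch.cuda.Stream()
x = torch.randn(8192, 14336, device="cuda", requires_grad=True)

# "Layer 0": relu saves its output for backward, so the pack hook offloads it.
with offload_to_cpu(offload_stream) as ctx0:
    h = torch.relu(x)

# "Layer 1": issue the next layer's compute first so that layer 0's pending
# D2H copies overlap with it, then wait and drop the last Python reference so
# layer 0's device activation is actually freed.
with offload_to_cpu(offload_stream) as ctx1:
    y = torch.sigmoid(h)
ctx0.wait_for_d2h()
del h

loss = y.sum()
# Prefetch layer 0's activation back to the device before backward reaches it;
# the unpack hook waits on this copy automatically. The last layer's context
# (ctx1) is left alone since backward consumes its activation first anyway.
ctx0.copy_h2d_async()
loss.backward()
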
62 changes: 62 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -479,6 +479,7 @@ def apply_dp(
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
# reshard_after_forward = False
fully_shard(
transformer_block,
**fsdp_config,
@@ -492,6 +493,64 @@ def apply_dp(
return model


def apply_offload(model: nn.Module):
import pynvml
from torchtitan.parallelisms.offload_utils import offload_to_cpu

# Set the CPU affinity based on the GPU ID
uuids = torch.cuda._raw_device_uuid_nvml()
dev_id = torch.cuda.current_device()
assert dev_id >= 0 and dev_id < len(uuids), f"{dev_id} {len(uuids)}"
uuid = uuids[dev_id]
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByUUID(uuid)
pynvml.nvmlDeviceSetCpuAffinity(handle)

offload_stream = torch.cuda.Stream()
layer_id_to_ctx = {}

def register_forward_hooks(module: nn.Module):

def forward_pre_hook(module, args):
ctx = offload_to_cpu(offload_stream)
ctx.__enter__()
module.offload_ctx = ctx
return args

def forward_hook(module, input, output):
module.offload_ctx.__exit__(None, None, None)
# Wait on the previous forward layer's D2H copies to free memory
layer_id = module.layer_id
if layer_id > 0:
layer_id_to_ctx[layer_id - 1].wait_for_d2h()
layer_id_to_ctx[layer_id] = module.offload_ctx
return output

module.register_forward_pre_hook(forward_pre_hook)
module.register_forward_hook(forward_hook)

def register_backward_hooks(module: nn.Module):

def backward_hook(module: nn.Module, grad_output: torch.Tensor):
# Prefetch the next backward layer's H2D copies to overlap
layer_id = module.layer_id
target_layer_id = layer_id - 1
if target_layer_id >= 0:
with torch.profiler.record_function(
f"copy_h2d_async for {target_layer_id}"
):
layer_id_to_ctx[target_layer_id].copy_h2d_async()
return

module.register_full_backward_pre_hook(backward_hook)

for layer in model.layers.values():
register_forward_hooks(layer)
register_backward_hooks(layer)

return model


def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
@@ -518,4 +577,7 @@ def parallelize_llama(
if parallel_dims.dp_enabled:
model = apply_dp(model, world_mesh, parallel_dims, job_config)

if job_config.experimental.offload_activations:
model = apply_offload(model)

return model
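
The affinity pinning at the top of apply_offload uses NVML to bind the process to the CPUs that are ideal for the current GPU, presumably so that the pinned CPU buffers and the D2H/H2D traffic stay on the GPU's local NUMA node. Below is a hedged sketch of the same calls with a readback for inspection; it assumes a Linux host for os.sched_getaffinity and is not part of the diff.

import os

import pynvml
import torch

# Hedged sketch: repeat the pinning done in apply_offload and print the
# resulting CPU affinity of the current process.
pynvml.nvmlInit()
uuid = torch.cuda._raw_device_uuid_nvml()[torch.cuda.current_device()]
handle = pynvml.nvmlDeviceGetHandleByUUID(uuid)
pynvml.nvmlDeviceSetCpuAffinity(handle)
print("CPU affinity:", sorted(os.sched_getaffinity(0)))
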
2 changes: 1 addition & 1 deletion torchtitan/profiling.py
@@ -14,7 +14,7 @@
from torchtitan.logging_utils import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3
WARMUP = 1

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
10 changes: 9 additions & 1 deletion train.py
@@ -400,6 +400,11 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

if job_config.experimental.offload_activations:
# NOTE: We need `gc.collect` to ensure that CPU memory is actually
# freed once no explicit refs to the offloaded tensors remain.
gc.collect()

# when fp8 config is on,
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
@@ -445,6 +450,7 @@ def loss_fn(pred, labels):
time_data_loading_pct = 100 * np.sum(data_loading_times) / time_delta

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
import psutil

metrics = {
"loss_metrics/global_avg_loss": global_avg_loss,
@@ -460,6 +466,7 @@ def loss_fn(pred, labels):
"memory/max_reserved(%)": gpu_mem_stats.max_reserved_pct,
"memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries,
"memory/num_ooms": gpu_mem_stats.num_ooms,
"cpu mem(%)": psutil.cpu_percent(),
}
metric_logger.log(metrics, step=train_state.step)

@@ -469,7 +476,8 @@ def loss_fn(pred, labels):
f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
f"{color.blue}wps: {round(wps):,} "
f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
f"{color.magenta}mfu: {mfu:.2f}%{color.reset} "
f"{color.white}cpu mem(%): {psutil.cpu_percent():.2f}%{color.reset}"
)

losses_since_last_log.clear()
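
As a companion to the gc.collect() call and the new "cpu mem(%)" metric above, here is a small hedged sketch (mine, not in the diff) of a per-step host-memory check; the function name and step argument are placeholders.

import gc

import psutil

def log_host_memory(step: int) -> None:
    # Force a cyclic collection so offloaded CPU buffers with no explicit refs
    # left are actually released (see the NOTE above), then sample host memory.
    gc.collect()
    rss_gib = psutil.Process().memory_info().rss / (1024 ** 3)
    ram_pct = psutil.virtual_memory().percent
    print(f"step {step}: RSS {rss_gib:.2f} GiB, system RAM used {ram_pct:.1f}%")
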
