diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2bd6e370..cba89436 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -241,6 +241,12 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--experimental.offload_activations",
+            default=False,
+            action="store_true",
+            help="Whether to offload activations to CPU to save GPU memory",
+        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             default=False,
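With this flag in place, activation offloading stays off by default. A minimal sketch of enabling it, assuming the usual mapping between `--<section>.<key>` CLI flags and TOML sections in config_manager (the launcher arguments and config path below are placeholders):

    # on the command line
    torchrun <...> train.py --job.config_file <path/to/config.toml> \
        --experimental.offload_activations

    # or, equivalently, in the TOML job config
    [experimental]
    offload_activations = true
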
+ """ + + def __init__(self, offload_stream: torch.cuda.Stream): + self.handle_key_to_handle: Dict[HandleKey, Handle] = {} + self.offload_stream = offload_stream + + def pack_to_cpu(tensor: torch.Tensor): + if tensor.device.type == "cpu": + return (tensor.device, tensor) + + device_tensor = tensor + del tensor + # TODO: Need a way to decide whether to offload this tensor or not + # that might need to be a function of the op constructing this + # tensor, pipeline parallel rank, etc. + if device_tensor.numel() < (14336 * 8192): # (FFN dim * seq_len) for 8B + return (device_tensor.device, device_tensor) + + handle = Handle(device_tensor, offload_stream) + handle.copy_d2h_async() + + assert handle.cpu_tensor is not None + handle_key = (device_tensor.device, handle.cpu_tensor) + self.handle_key_to_handle[handle_key] = handle + + return handle_key + + def unpack_from_cpu(handle_key: HandleKey): + device, tensor = handle_key + if tensor.device == device: + return tensor + + assert tensor.device == torch.device("cpu"), f"{tensor.device}" + cpu_tensor = tensor + del tensor + + handle = self.handle_key_to_handle.get(handle_key, None) + if handle is None: + raise RuntimeError(f"Handle missing for {handle_key}") + + handle.wait_for_h2d() + if handle.device_tensor is not None: + device_tensor = handle.device_tensor + handle.device_tensor = None + return device_tensor + + # Fallback to non-overlapped H2D copy + device_tensor = cpu_tensor.to(device, non_blocking=True) + assert handle.cpu_tensor is None + return device_tensor + + super().__init__(pack_to_cpu, unpack_from_cpu) + + def wait_for_d2h(self): + for handle in self.handle_key_to_handle.values(): + handle.wait_for_d2h() + + def copy_h2d_async(self): + # HACK: Sleeping for 1 ms before copy H2D helps avoid the no-overlap + # issue for `reshard_after_forward=True` where AG copy-out's H2D copy + # serializes after these H2D copies, preventing overlap. 
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index ec0f6763..72155845 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -479,6 +479,7 @@ def apply_dp(
         # As an optimization, do not reshard after forward for the last
         # transformer block since FSDP would prefetch it immediately
         reshard_after_forward = int(layer_id) < len(model.layers) - 1
+        # reshard_after_forward = False
         fully_shard(
             transformer_block,
             **fsdp_config,
@@ -492,6 +493,64 @@ def apply_dp(
     return model
 
 
+def apply_offload(model: nn.Module):
+    import pynvml
+    from torchtitan.parallelisms.offload_utils import offload_to_cpu
+
+    # Set the CPU affinity based on the GPU ID
+    uuids = torch.cuda._raw_device_uuid_nvml()
+    dev_id = torch.cuda.current_device()
+    assert dev_id >= 0 and dev_id < len(uuids), f"{dev_id} {len(uuids)}"
+    uuid = uuids[dev_id]
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByUUID(uuid)
+    pynvml.nvmlDeviceSetCpuAffinity(handle)
+
+    offload_stream = torch.cuda.Stream()
+    layer_id_to_ctx = {}
+
+    def register_forward_hooks(module: nn.Module):
+
+        def forward_pre_hook(module, args):
+            ctx = offload_to_cpu(offload_stream)
+            ctx.__enter__()
+            module.offload_ctx = ctx
+            return args
+
+        def forward_hook(module, input, output):
+            module.offload_ctx.__exit__(None, None, None)
+            # Wait on the previous forward layer's D2H copies to free memory
+            layer_id = module.layer_id
+            if layer_id > 0:
+                layer_id_to_ctx[layer_id - 1].wait_for_d2h()
+            layer_id_to_ctx[layer_id] = module.offload_ctx
+            return output
+
+        module.register_forward_pre_hook(forward_pre_hook)
+        module.register_forward_hook(forward_hook)
+
+    def register_backward_hooks(module: nn.Module):
+
+        def backward_hook(module: nn.Module, grad_output: torch.Tensor):
+            # Prefetch the next backward layer's (layer_id - 1) H2D copies to overlap
+            layer_id = module.layer_id
+            target_layer_id = layer_id - 1
+            if target_layer_id >= 0:
+                with torch.profiler.record_function(
+                    f"copy_h2d_async for {target_layer_id}"
+                ):
+                    layer_id_to_ctx[target_layer_id].copy_h2d_async()
+            return
+
+        module.register_full_backward_pre_hook(backward_hook)
+
+    for layer in model.layers.values():
+        register_forward_hooks(layer)
+        register_backward_hooks(layer)
+
+    return model
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
@@ -518,4 +577,7 @@ def parallelize_llama(
     if parallel_dims.dp_enabled:
         model = apply_dp(model, world_mesh, parallel_dims, job_config)
 
+    if job_config.experimental.offload_activations:
+        model = apply_offload(model)
+
     return model
diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index 662b64f8..1c8a5022 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -14,7 +14,7 @@ from torchtitan.logging_utils import logger
 
 # the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
+WARMUP = 1
 
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
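Two notes on `apply_offload`: the hooks stagger the copies so that layer i's D2H copies are only waited on after layer i+1's forward has been issued, and layer i-1's H2D copies are prefetched from layer i's backward pre-hook; separately, the pynvml calls restrict the process to the CPUs associated with the current GPU, presumably to keep the pinned-memory copies NUMA-local. A small standalone sketch of that affinity step, not part of the diff (`os.sched_getaffinity` is Linux-only):

    import os

    import pynvml
    import torch

    pynvml.nvmlInit()
    uuid = torch.cuda._raw_device_uuid_nvml()[torch.cuda.current_device()]
    handle = pynvml.nvmlDeviceGetHandleByUUID(uuid)
    pynvml.nvmlDeviceSetCpuAffinity(handle)  # restrict to this GPU's NUMA-local CPUs
    print(sorted(os.sched_getaffinity(0)))   # CPUs the process may now run on
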
diff --git a/train.py b/train.py
index afd1d888..c18adb6c 100644
--- a/train.py
+++ b/train.py
@@ -400,6 +400,11 @@ def loss_fn(pred, labels):
         optimizers.step()
         lr_schedulers.step()
 
+        if job_config.experimental.offload_activations:
+            # NOTE: We need `gc.collect` to ensure that the offloaded CPU
+            # memory is freed even though no explicit references remain.
+            gc.collect()
+
         # when fp8 config is on,
         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # it issues a single all-reduce for all parameters at once for better performance
@@ -445,6 +450,7 @@ def loss_fn(pred, labels):
             time_data_loading_pct = 100 * np.sum(data_loading_times) / time_delta
 
             gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
+            import psutil
 
             metrics = {
                 "loss_metrics/global_avg_loss": global_avg_loss,
@@ -460,6 +466,7 @@ def loss_fn(pred, labels):
                 "memory/max_reserved(%)": gpu_mem_stats.max_reserved_pct,
                 "memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries,
                 "memory/num_ooms": gpu_mem_stats.num_ooms,
+                "cpu mem(%)": psutil.virtual_memory().percent,
             }
             metric_logger.log(metrics, step=train_state.step)
 
@@ -469,7 +476,8 @@ def loss_fn(pred, labels):
                 f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
                 f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
                 f"{color.blue}wps: {round(wps):,} "
-                f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
+                f"{color.magenta}mfu: {mfu:.2f}%{color.reset} "
+                f"{color.white}cpu mem(%): {psutil.virtual_memory().percent:.2f}%{color.reset}"
             )
 
             losses_since_last_log.clear()
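One observation on the metrics change: `psutil.cpu_percent()` reports CPU utilization, not memory, so the logged value above uses `psutil.virtual_memory().percent` to match the "cpu mem(%)" label. For reference, a short illustrative snippet (not part of the diff) of the relevant psutil readings:

    import psutil

    psutil.cpu_percent(interval=None)    # CPU utilization since the previous call, in percent
    psutil.virtual_memory().percent      # fraction of total system RAM currently in use, in percent
    psutil.Process().memory_info().rss   # resident set size of this process, in bytes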