From eb90941398454fc8d81e00b14ba503ab7da22972 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Mon, 29 Jul 2024 07:35:53 -0700
Subject: [PATCH] [Not for land] GaLore example

[ghstack-poisoned]
---
 torchtitan/config_manager.py      |  24 ++++
 torchtitan/lr_scheduling.py       |  22 ++--
 torchtitan/optims/galore_adamw.py | 210 ++++++++++++++++++++++++++++++
 train.py                          |  59 +++++++++
 4 files changed, 304 insertions(+), 11 deletions(-)
 create mode 100644 torchtitan/optims/galore_adamw.py

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 9a086830..a554da9f 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -187,6 +187,30 @@ def __init__(self):
             action="store_true",
             help="Whether the fused implementation(CUDA only) is used.",
         )
+        self.parser.add_argument(
+            "--optimizer.galore_rank", type=int, default=128, help="GaLore rank"
+        )
+        self.parser.add_argument(
+            "--optimizer.galore_update_proj_gap",
+            type=int,
+            default=200,
+            help="GaLore update projection gap",
+        )
+        self.parser.add_argument(
+            "--optimizer.galore_scale", type=float, default=1.0, help="GaLore scale"
+        )
+        self.parser.add_argument(
+            "--optimizer.galore_proj_type",
+            type=str,
+            default="std",
+            help="GaLore projection type",
+        )
+        self.parser.add_argument(
+            "--optimizer.galore_in_backward",
+            default=False,
+            action="store_true",
+            help="Whether to apply GaLore in backward"
+        )
 
         # training configs
         self.parser.add_argument(
diff --git a/torchtitan/lr_scheduling.py b/torchtitan/lr_scheduling.py
index 9f766268..678ff28f 100644
--- a/torchtitan/lr_scheduling.py
+++ b/torchtitan/lr_scheduling.py
@@ -33,16 +33,6 @@ def linear_warmup_linear_decay(
 
 
 def get_lr_schedulers(optimizers, job_config: JobConfig):
-    def _get_lr_scheduler(optimizer):
-        """Build a linear warmup and linear decay scheduler"""
-        warmup_steps = int(job_config.training.warmup_steps)
-        decay_steps = float(max(1, job_config.training.steps - warmup_steps))
-        lr_lambda = functools.partial(
-            linear_warmup_linear_decay, warmup_steps, decay_steps
-        )
-        warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
-        return warmup_scheduler
-
     class SchedulersContainer:
         """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""
 
@@ -54,5 +44,15 @@ def step(self):
                 schedulers.step()
 
     return SchedulersContainer(
-        [_get_lr_scheduler(optimizer) for optimizer in optimizers]
+        [_get_lr_scheduler(job_config, optimizer) for optimizer in optimizers]
+    )
+
+def _get_lr_scheduler(job_config: JobConfig, optimizer):
+    """Build a linear warmup and linear decay scheduler"""
+    warmup_steps = int(job_config.training.warmup_steps)
+    decay_steps = float(max(1, job_config.training.steps - warmup_steps))
+    lr_lambda = functools.partial(
+        linear_warmup_linear_decay, warmup_steps, decay_steps
     )
+    warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
+    return warmup_scheduler
diff --git a/torchtitan/optims/galore_adamw.py b/torchtitan/optims/galore_adamw.py
new file mode 100644
index 00000000..99931ba2
--- /dev/null
+++ b/torchtitan/optims/galore_adamw.py
@@ -0,0 +1,210 @@
+"""
+Credit: https://github.com/jiaweizzhao/GaLore/tree/master
+(copied over and condensed for convenience)
+"""
+
+import math
+from typing import Callable, Iterable, Tuple
+
+import torch
+import torch.nn as nn
+
+from torchtitan.logging_utils import logger
+
+
+class GaLoreAdamW(torch.optim.Optimizer):
+    def __init__(
+        self,
+        params: Iterable[nn.Parameter],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-6,
+        weight_decay: float = 0.0,
+        correct_bias: bool = True,
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)"
+            )
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)"
+            )
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "correct_bias": correct_bias,
+        }
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure: Callable = None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+
+                state = self.state[p]
+                if "step" not in state:
+                    state["step"] = 0
+                if "dim" not in group:
+                    group["dim"] = 2
+
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+                    grad = state["projector"].project(grad, state["step"])
+
+                if "exp_avg" not in state:
+                    state["exp_avg"] = torch.zeros_like(
+                        grad, memory_format=torch.preserve_format
+                    )
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        grad, memory_format=torch.preserve_format
+                    )
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+                denom = exp_avg_sq.sqrt().add_(group["eps"])
+                step_size = group["lr"]
+                if group["correct_bias"]:
+                    bias_correction1 = 1.0 - beta1 ** state["step"]
+                    bias_correction2 = 1.0 - beta2 ** state["step"]
+                    step_size = (
+                        step_size * math.sqrt(bias_correction2) / bias_correction1
+                    )
+
+                norm_grad = exp_avg / denom
+
+                if "rank" in group:
+                    norm_grad = state["projector"].project_back(norm_grad)
+
+                p.add_(norm_grad, alpha=-step_size)
+
+                if group["weight_decay"] > 0.0:
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+
+        return loss
+
+
+class GaLoreProjector:
+    def __init__(
+        self,
+        rank: int,
+        update_proj_gap: int = 200,
+        scale: float = 1.0,
+        proj_type: str = "std",
+    ):
+        self.rank = rank
+        self.update_proj_gap = update_proj_gap
+        self.scale = scale
+        self.ortho_matrix = None
+        self.proj_type = proj_type
+
+    def project(self, full_rank_grad: torch.Tensor, iter_idx: int):
+        if self.proj_type == "std":
+            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+                if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(
+                        full_rank_grad, self.rank, type="right"
+                    )
+                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+            else:
+                if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(
+                        full_rank_grad, self.rank, type="left"
+                    )
+                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+        elif self.proj_type == "reverse_std":
+            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+                if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(
+                        full_rank_grad, self.rank, type="left"
+                    )
+                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+            else:
+                if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(
+                        full_rank_grad, self.rank, type="right"
+                    )
+                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+        elif self.proj_type == "right":
+            if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                self.ortho_matrix = self.get_orthogonal_matrix(
+                    full_rank_grad, self.rank, type="right"
+                )
+            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+        elif self.proj_type == "left":
+            if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                self.ortho_matrix = self.get_orthogonal_matrix(
+                    full_rank_grad, self.rank, type="left"
+                )
+            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+        elif self.proj_type == "full":
+            if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
+                self.ortho_matrix = self.get_orthogonal_matrix(
+                    full_rank_grad, self.rank, type="full"
+                )
+            low_rank_grad = (
+                torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
+                @ self.ortho_matrix[1].t()
+            )
+
+        return low_rank_grad
+
+    def project_back(self, low_rank_grad: torch.Tensor):
+        if self.proj_type == "std":
+            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
+                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+            else:
+                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+        elif self.proj_type == "reverse_std":
+            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:
+                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+            else:
+                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+        elif self.proj_type == "right":
+            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+        elif self.proj_type == "left":
+            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+        elif self.proj_type == "full":
+            full_rank_grad = (
+                torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
+            )
+
+        return full_rank_grad * self.scale
+
+    def get_orthogonal_matrix(self, param: torch.Tensor, rank: int, type: str):
+        U, s, Vh = torch.linalg.svd(param.detach(), full_matrices=False)
+        if type == "right":
+            B = Vh[:rank, :].to(device=param.device, dtype=param.dtype)
+            return B
+        elif type == "left":
+            A = U[:, :rank].to(device=param.device, dtype=param.dtype)
+            return A
+        elif type == "full":
+            A = U[:, :rank].to(device=param.device, dtype=param.dtype)
+            B = Vh[:rank, :].to(device=param.device, dtype=param.dtype)
+            return [A, B]
+        else:
+            raise ValueError(f"type should be left, right, or full but got {type}")
diff --git a/train.py b/train.py
index b7eee302..916c4a85 100644
--- a/train.py
+++ b/train.py
@@ -116,6 +116,65 @@ def _build_optimizer(model):
             optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
         elif name == "AdamW":
             optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
+        elif name == "GaLoreAdamW":
+            from torchtitan.optims.galore_adamw import GaLoreAdamW
+
+            optimizer_kwargs.pop("fused")
+            optimizer_kwargs.pop("foreach")
+            galore_kwargs = {}
+            galore_kwargs["rank"] = job_config.optimizer.galore_rank
+            galore_kwargs["update_proj_gap"] = (
+                job_config.optimizer.galore_update_proj_gap
+            )
+            galore_kwargs["scale"] = job_config.optimizer.galore_scale
+            galore_kwargs["proj_type"] = job_config.optimizer.galore_proj_type
+            nongalore_params = []
+            galore_params = []
+            for module_name, module in model.named_modules():
+                for param_name, param in module.named_parameters(recurse=False):
+                    if (
+                        isinstance(module, torch.nn.Linear)
+                        and "weight" in param_name
+                        and (
+                            "attention" in module_name or "feed_forward" in module_name
+                        )
+                    ):
+                        galore_params.append(param)
+                    else:
+                        nongalore_params.append(param)
+            if not job_config.optimizer.galore_in_backward:
+                param_groups = [
+                    {"params": nongalore_params},
+                    {"params": galore_params, **galore_kwargs},
+                ]
+                optimizer = GaLoreAdamW(param_groups, **optimizer_kwargs)
+                return optimizer
+            else:
+                from torchtitan.lr_scheduling import _get_lr_scheduler
+
+                param_to_optim: Dict[nn.Parameter, torch.optim.Optimizer] = {}
+                for param in nongalore_params:
+                    param_to_optim[param] = GaLoreAdamW([param], **optimizer_kwargs)
+                for param in galore_params:
+                    param_group = [{"params": [param], **galore_kwargs}]
+                    param_to_optim[param] = GaLoreAdamW(param_group, **optimizer_kwargs)
+
+                param_to_scheduler: Dict[nn.Parameter, torch.optim.LRScheduler] = {}
+                for param, optim in param_to_optim.items():
+                    param_to_scheduler[param] = _get_lr_scheduler(job_config, optim)
+
+                def optimizer_hook(param: torch.nn.Parameter) -> None:
+                    if param.grad is None:
+                        return
+                    optim = param_to_optim[param]
+                    optim.step()
+                    optim.zero_grad()
+                    param_to_scheduler[param].step()
+
+                for param in param_to_optim:
+                    param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                return GaLoreAdamW([torch.empty(0)], **optimizer_kwargs)
         else:
             raise NotImplementedError(f"Optimizer {name} not added.")
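
Usage note (not part of the patch): the sketch below shows one way to drive the GaLoreAdamW optimizer added above on its own, mirroring the param-group split that the train.py hunk performs. The toy model, rank, and hyperparameters are illustrative assumptions for the example; only the GaLoreAdamW import and the per-group "rank"/"update_proj_gap"/"scale"/"proj_type" keys come from the patch.

# Minimal usage sketch (illustrative only): exercises GaLoreAdamW from this patch
# on a toy model. Shapes and hyperparameters below are assumptions, not torchtitan
# defaults.
import torch
import torch.nn as nn

from torchtitan.optims.galore_adamw import GaLoreAdamW  # available once the patch is applied

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))

# Mirror train.py: 2D Linear weights go into the GaLore group, where their
# gradients are projected to a rank-r subspace before the AdamW update;
# everything else (biases here) is optimized as plain AdamW.
galore_params = [m.weight for m in model.modules() if isinstance(m, nn.Linear)]
galore_ids = {id(p) for p in galore_params}
other_params = [p for p in model.parameters() if id(p) not in galore_ids]

optimizer = GaLoreAdamW(
    [
        {"params": other_params},
        {
            "params": galore_params,
            "rank": 32,              # low-rank dimension of the projected gradient
            "update_proj_gap": 200,  # refresh the SVD-based projection every 200 steps
            "scale": 1.0,
            "proj_type": "std",
        },
    ],
    lr=1e-3,
)

for _ in range(3):
    out = model(torch.randn(8, 256))
    loss = out.square().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

In torchtitan itself, the same grouping is selected through the new flags, e.g. --optimizer.name GaLoreAdamW --optimizer.galore_rank 128. With --optimizer.galore_in_backward, train.py instead builds one GaLoreAdamW per parameter and registers it via register_post_accumulate_grad_hook, so each optimizer step runs as that parameter's gradient becomes ready during backward.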