[Not for land] GaLore example #488

Draft · wants to merge 1 commit into base: gh/awgu/10/base
24 changes: 24 additions & 0 deletions torchtitan/config_manager.py
@@ -187,6 +187,30 @@ def __init__(self):
action="store_true",
help="Whether the fused implementation(CUDA only) is used.",
)
self.parser.add_argument(
"--optimizer.galore_rank", type=int, default=128, help="GaLore rank"
)
self.parser.add_argument(
"--optimizer.galore_update_proj_gap",
type=int,
default=200,
help="GaLore update projection gap",
)
self.parser.add_argument(
"--optimizer.galore_scale", type=float, default=1.0, help="GaLore scale"
)
self.parser.add_argument(
"--optimizer.galore_proj_type",
type=str,
default="std",
help="GaLore projection type",
)
self.parser.add_argument(
"--optimizer.galore_in_backward",
default=False,
action="store_true",
help="Whether to apply GaLore in backward"
)

# training configs
self.parser.add_argument(
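
Not part of the diff: a minimal sketch of how the new flags surface on the parsed config. It assumes JobConfig.parse_args accepts an explicit argv-style list (with no TOML config file) and that the attribute names mirror the --optimizer.galore_* arguments added above.

# Hypothetical usage of the new GaLore flags (assumes JobConfig.parse_args
# can take an argv list directly; all other options keep their defaults).
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--optimizer.name", "GaLoreAdamW",
        "--optimizer.galore_rank", "64",
        "--optimizer.galore_update_proj_gap", "100",
        "--optimizer.galore_scale", "0.25",
        "--optimizer.galore_proj_type", "std",
        "--optimizer.galore_in_backward",
    ]
)
assert config.optimizer.galore_rank == 64
assert config.optimizer.galore_in_backward is True
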
22 changes: 11 additions & 11 deletions torchtitan/lr_scheduling.py
@@ -33,16 +33,6 @@ def linear_warmup_linear_decay(


def get_lr_schedulers(optimizers, job_config: JobConfig):
def _get_lr_scheduler(optimizer):
"""Build a linear warmup and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(
linear_warmup_linear_decay, warmup_steps, decay_steps
)
warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_scheduler

class SchedulersContainer:
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

@@ -54,5 +44,15 @@ def step(self):
schedulers.step()

return SchedulersContainer(
[_get_lr_scheduler(optimizer) for optimizer in optimizers]
[_get_lr_scheduler(job_config, optimizer) for optimizer in optimizers]
)

def _get_lr_scheduler(job_config: JobConfig, optimizer):
"""Build a linear warmup and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
lr_lambda = functools.partial(
linear_warmup_linear_decay, warmup_steps, decay_steps
)
warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return warmup_scheduler
210 changes: 210 additions & 0 deletions torchtitan/optims/galore_adamw.py
@@ -0,0 +1,210 @@
"""
Credit: https://github.com/jiaweizzhao/GaLore/tree/master
(copied over and condensed for convenience)
"""

import math
from typing import Callable, Iterable, Tuple

import torch
import torch.nn as nn

from torchtitan.logging_utils import logger


class GaLoreAdamW(torch.optim.Optimizer):
def __init__(
self,
params: Iterable[nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(
f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)"
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError(
f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)"
)
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"correct_bias": correct_bias,
}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure: Callable = None):
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad

state = self.state[p]
if "step" not in state:
state["step"] = 0
if "dim" not in group:
group["dim"] = 2

if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
grad = state["projector"].project(grad, state["step"])

if "exp_avg" not in state:
state["exp_avg"] = torch.zeros_like(
grad, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
grad, memory_format=torch.preserve_format
)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

state["step"] += 1
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]:
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = (
step_size * math.sqrt(bias_correction2) / bias_correction1
)

norm_grad = exp_avg / denom

if "rank" in group:
norm_grad = state["projector"].project_back(norm_grad)

p.add_(norm_grad, alpha=-step_size)

if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

return loss


class GaLoreProjector:
def __init__(
self,
rank: int,
update_proj_gap: int = 200,
scale: float = 1.0,
proj_type: str = "std",
):
self.rank = rank
self.update_proj_gap = update_proj_gap
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type

def project(self, full_rank_grad: torch.Tensor, iter_idx: int):
if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right"
)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left"
)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left"
)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
else:
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right"
)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "right":
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="right"
)
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "left":
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="left"
)
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "full":
if self.ortho_matrix is None or iter_idx % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
full_rank_grad, self.rank, type="full"
)
low_rank_grad = (
torch.matmul(self.ortho_matrix[0].t(), full_rank_grad)
@ self.ortho_matrix[1].t()
)

return low_rank_grad

def project_back(self, low_rank_grad: torch.Tensor):
if self.proj_type == "std":
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "reverse_std":
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "right":
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "left":
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "full":
full_rank_grad = (
torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
)

return full_rank_grad * self.scale

def get_orthogonal_matrix(self, param: torch.Tensor, rank: int, type: str):
U, s, Vh = torch.linalg.svd(param.detach(), full_matrices=False)
if type == "right":
B = Vh[:rank, :].to(device=param.device, dtype=param.dtype)
return B
elif type == "left":
A = U[:, :rank].to(device=param.device, dtype=param.dtype)
return A
elif type == "full":
A = U[:, :rank].to(device=param.device, dtype=param.dtype)
B = Vh[:rank, :].to(device=param.device, dtype=param.dtype)
return [A, B]
else:
raise ValueError(f"type should be left, right, or full but got {type}")
59 changes: 59 additions & 0 deletions train.py
@@ -116,6 +116,65 @@ def _build_optimizer(model):
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
elif name == "AdamW":
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
elif name == "GaLoreAdamW":
from torchtitan.optims.galore_adamw import GaLoreAdamW

optimizer_kwargs.pop("fused")
optimizer_kwargs.pop("foreach")
galore_kwargs = {}
galore_kwargs["rank"] = job_config.optimizer.galore_rank
galore_kwargs["update_proj_gap"] = (
job_config.optimizer.galore_update_proj_gap
)
galore_kwargs["scale"] = job_config.optimizer.galore_scale
galore_kwargs["proj_type"] = job_config.optimizer.galore_proj_type
nongalore_params = []
galore_params = []
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
if (
isinstance(module, torch.nn.Linear)
and "weight" in param_name
and (
"attention" in module_name or "feed_forward" in module_name
)
):
galore_params.append(param)
else:
nongalore_params.append(param)
if not job_config.optimizer.galore_in_backward:
param_groups = [
{"params": nongalore_params},
{"params": galore_params, **galore_kwargs},
]
optimizer = GaLoreAdamW(param_groups, **optimizer_kwargs)
return optimizer
else:
from torchtitan.lr_scheduling import _get_lr_scheduler

param_to_optim: Dict[nn.Parameter, torch.optim.Optimizer] = {}
for param in nongalore_params:
param_to_optim[param] = GaLoreAdamW([param], **optimizer_kwargs)
for param in galore_params:
param_group = [{"params": [param], **galore_kwargs}]
param_to_optim[param] = GaLoreAdamW(param_group, **optimizer_kwargs)

param_to_scheduler: Dict[nn.Parameter, torch.optim.LRScheduler] = {}
for param, optim in param_to_optim.items():
param_to_scheduler[param] = _get_lr_scheduler(job_config, optim)

def optimizer_hook(param: torch.nn.Parameter) -> None:
if param.grad is None:
return
optim = param_to_optim[param]
optim.step()
optim.zero_grad()
param_to_scheduler[param].step()

for param in param_to_optim:
param.register_post_accumulate_grad_hook(optimizer_hook)

return GaLoreAdamW([torch.empty(0)], **optimizer_kwargs)
else:
raise NotImplementedError(f"Optimizer {name} not added.")

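
For context on the galore_in_backward branch above: each parameter gets its own optimizer (and LR scheduler), and register_post_accumulate_grad_hook runs the update as soon as that parameter's gradient has been accumulated, so gradients can be released before the rest of the backward pass finishes. The GaLoreAdamW([torch.empty(0)], ...) returned at the end appears to be a placeholder so that the outer training loop's optimizer.step()/zero_grad() calls become no-ops. Below is a minimal standalone sketch of the same hook pattern, using plain AdamW and omitting the scheduler and GaLore pieces.

# Sketch of "optimizer in backward" via post-accumulate-grad hooks (standalone).
import torch
import torch.nn as nn

model = nn.Linear(128, 128)

# One optimizer per parameter so each can be stepped independently.
param_to_optim = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param: torch.nn.Parameter) -> None:
    # Called right after param.grad is accumulated during backward.
    opt = param_to_optim[param]
    opt.step()
    opt.zero_grad()  # releases the gradient early, reducing peak memory

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# Training step: no explicit optimizer.step() afterwards; updates already ran.
out = model(torch.randn(4, 128))
out.pow(2).mean().backward()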