Used partial instead of global vars for LR scheduling
ghstack-source-id: 12c4418b0574d93e1441f4ca3d1de79c8aad7a40
Pull Request resolved: #487
awgu committed Jul 29, 2024
1 parent 42f4ff5 commit 12e0ada
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions torchtitan/lr_scheduling.py
@@ -4,43 +4,43 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 from torch.optim.lr_scheduler import LambdaLR
 from torchtitan.config_manager import JobConfig
 
-# global states for scheduling
-# these are needed as LambdaLR does not support argument passing
-_warmup_steps = 200
-_decay_steps = 0
-
 
-def linear_warmup_linear_decay(current_step: int) -> float:
+def linear_warmup_linear_decay(
+    warmup_steps: int, decay_steps: int, current_step: int
+) -> float:
     """Computes linear warmup followed by linear decay.
     Per LambdaLR requirement, this is accomplished by returning
     a multiplicative factor to adjust the learning rate to
     create the desired schedule.
     """
-    if current_step < _warmup_steps:
+    if current_step < warmup_steps:
         # linear warmup
         # 0-indexed step, hence + 1 adjustments
         current_step += 1
-        curr_adjustment = float(current_step / (_warmup_steps + 1))
+        curr_adjustment = float(current_step / (warmup_steps + 1))
 
     else:
         # linear decay
-        normalized_step = _decay_steps - (current_step - _warmup_steps)
-        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+        normalized_step = decay_steps - (current_step - warmup_steps)
+        curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps
 
     return curr_adjustment
 
 
 def get_lr_schedulers(optimizers, job_config: JobConfig):
     def _get_lr_scheduler(optimizer):
         """Build a linear warmup and linear decay scheduler"""
-        global _warmup_steps, _decay_steps
-        _warmup_steps = int(job_config.training.warmup_steps)
-        _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
-
-        warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+        warmup_steps = int(job_config.training.warmup_steps)
+        decay_steps = float(max(1, job_config.training.steps - warmup_steps))
+        lr_lambda = functools.partial(
+            linear_warmup_linear_decay, warmup_steps, decay_steps
+        )
+        warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
         return warmup_scheduler
 
     class SchedulersContainer:
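For context, a minimal usage sketch of the new scheme (not part of this commit): LambdaLR invokes lr_lambda with a single step index, so warmup_steps and decay_steps are pre-bound with functools.partial, exactly as the changed _get_lr_scheduler now does. The toy model, optimizer, learning rate, and step counts below are assumed values for illustration only.

    import functools

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    from torchtitan.lr_scheduling import linear_warmup_linear_decay

    # Placeholder module and optimizer, only to give LambdaLR something to drive.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Assumed schedule: 2 warmup steps out of 10 total, so 8 decay steps.
    warmup_steps, decay_steps = 2, 8

    # LambdaLR calls lr_lambda(step) with one argument; partial pre-binds the
    # other two, mirroring the change made inside _get_lr_scheduler.
    lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

    for step in range(10):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())

With these assumed values the multiplicative factor ramps 1/3 -> 2/3 over the two warmup steps, reaches 1.0 at step 2, then decays by 1/8 per step toward 0.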
