
CPUOffloadOptimizer incompatible with learning rate schedulers #959

Open

bghira opened this issue Sep 26, 2024 · 5 comments

Comments

bghira commented Sep 26, 2024

in get_polynomial_decay_schedule_with_warmup

    lr_init = optimizer.defaults["lr"]
AttributeError: 'CPUOffloadOptimizer' object has no attribute 'defaults'

I'm not sure why this attribute isn't exposed for external callers; the same code works without the offload optimizer class.


bghira commented Sep 26, 2024

With the constant scheduler:

File "site-packages/torch/optim/lr_scheduler.py", line 93, in __init__
    raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
TypeError: CPUOffloadOptimizer is not an Optimizer


bghira commented Sep 26, 2024

2024-09-26 14:19:28,114 [INFO] cls: <class 'torchao.prototype.low_bit_optim.adam.AdamW8bit'>, settings: {'betas': (0.9, 0.999), 'weight_decay': 0.01, 'eps': 1e-06}
2024-09-26 14:19:28,117 [INFO] Optimizer arguments={'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01, 'eps': 1e-06}
2024-09-26 14:19:28,148 [INFO] Loading constant learning rate scheduler with 100 warmup steps
2024-09-26 14:19:28,148 [INFO] Using generic 'constant' learning rate scheduler.
CPUOffloadOptimizer is not an Optimizer

Other info about the optimizer config is above; not sure whether it is helpful.

@gau-nernst (Collaborator)

Hello @bghira, for the CPU offload optimizer you have to update the LR manually.

#584 (comment)

Yeah, perhaps we should update the docs to make this clearer. Let me know if you still have problems.
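
For illustration, a minimal sketch of what "updating the LR manually" can look like: compute the learning rate yourself each step and write it into the optimizer's parameter groups. This assumes the CPUOffloadOptimizer wrapper exposes `param_groups` like a regular optimizer; the constructor call follows the pattern shown in the torchao docs, and the warmup schedule here is only an example, not the reporter's actual config.

```python
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(128, 128).cuda()
# wrap a regular PyTorch optimizer class; optimizer state is kept on CPU
optimizer = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-4)

def constant_with_warmup(step, base_lr=1e-4, warmup_steps=100):
    # linear warmup to base_lr, then constant (illustrative schedule)
    return base_lr * min(1.0, (step + 1) / warmup_steps)

for step in range(1000):
    loss = model(torch.randn(8, 128, device="cuda")).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # manual LR update instead of a torch.optim.lr_scheduler object
    new_lr = constant_with_warmup(step)
    for group in optimizer.param_groups:  # assumption: wrapper exposes param_groups
        group["lr"] = new_lr
```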

@gau-nernst (Collaborator)

@bghira We can continue the discussion here (instead of at the PR) for better visibility.

> i think it should not be referred to as a drop-in replacement then

I don't think the CPU offload optimizer is described as a drop-in replacement (though that doesn't mean it shouldn't be). There are already many other caveats that I believe users should be aware of.

Regarding the LR schedule issue, as I mentioned in the PR, the problem is that PyTorch's LR scheduler is hard-coded to check that the optimizer is a subclass of torch.optim.Optimizer, and I don't want to make CPUOffloadOptimizer such a subclass. Personally I don't use PyTorch's LR scheduler, so it was not an issue for me. But if it makes life easier for most other people, we could make it subclass torch.optim.Optimizer (and hope users don't run into other issues caused by that subclassing).

Just curious, from your perspective, would it be too much to ask users to also update the LR schedule code? Since you already need to modify some code to use the CPU offload optimizer, it doesn't seem like much extra to also change the LR schedule code.
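
For the specific traceback above, the schedule that get_polynomial_decay_schedule_with_warmup would have produced can also be computed as a plain function of the step and applied the same way as in the sketch earlier. A rough sketch follows; the formula approximates the usual warmup + polynomial decay shape and is not copied from the transformers implementation, and all parameter values are placeholders.

```python
def polynomial_decay_with_warmup(step, *, base_lr, warmup_steps, total_steps,
                                 end_lr=1e-7, power=1.0):
    # linear warmup to base_lr, then polynomial decay towards end_lr
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    if step >= total_steps:
        return end_lr
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return end_lr + (base_lr - end_lr) * remaining ** power

# in the training loop, after optimizer.step():
lr = polynomial_decay_with_warmup(step, base_lr=1e-4, warmup_steps=100, total_steps=10_000)
for group in optimizer.param_groups:  # same assumption as above
    group["lr"] = lr
```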


bghira commented Sep 27, 2024

It would need to be propagated up to the Hugging Face library eventually, but I have plenty of local overrides.
