
Low bit Optimizers & FA-3 #742

Open
asahni04 opened this issue Dec 16, 2024 · 4 comments

@asahni04

Hi, have there been any tests with FA-3 and the low-bit optimizers from torchao, such as FP8 Adam or 8-bit Adam? I see divergence in training when resuming an FA-2 checkpoint with FA-3, or when using 8-bit AdamW.
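For context, the torchao low-bit optimizers referred to here are intended as drop-in replacements for torch.optim.AdamW. A minimal sketch, assuming the torchao.prototype.low_bit_optim module path from late-2024 releases (class names and paths may differ in other versions):

```python
# Minimal sketch of swapping in torchao's low-bit optimizers; the module path
# and class names (AdamW8bit, AdamWFp8) are assumed from the prototype API and
# may differ depending on the installed torchao version.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # , AdamWFp8

model = torch.nn.Linear(4096, 4096).cuda()  # toy stand-in for the real model

# optim = torch.optim.AdamW(model.parameters(), lr=3e-4)  # fp32 baseline
optim = AdamW8bit(model.parameters(), lr=3e-4)             # 8-bit optimizer states
# optim = AdamWFp8(model.parameters(), lr=3e-4)            # fp8 optimizer states (recent GPUs)

for _ in range(3):
    loss = model(torch.randn(8, 4096, device="cuda")).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
```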
@fegin
Contributor

fegin commented Dec 16, 2024

cc: @weifengpy

@weifengpy
Contributor

Hey @asahni04, do you happen to have a breakdown?

  • baseline: load the FA-2 checkpoint with the FA-2 model and AdamW
  • switch only to FA-3
  • switch only to 8-bit AdamW

That would help clarify whether it's FA-3 (the model state dict) or 8-bit AdamW (the optimizer state dict).
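A minimal sketch of that breakdown, for concreteness (the build_model factory, the checkpoint layout, and the commented-out torchao import are placeholders standing in for the real training script):

```python
# Sketch of the ablation: change one variable at a time so a failure points at
# either the model/attention path or the optimizer path. build_model and the
# checkpoint layout are placeholders, not the actual training script.
import torch
from torch.optim import AdamW

def build_model(attn_impl: str) -> torch.nn.Module:
    # placeholder for a factory that selects the FA-2 vs FA-3 attention kernel
    return torch.nn.Linear(16, 16)

# stand-in for an FA-2 checkpoint holding both model and optimizer state
base = build_model("fa2")
base_opt = AdamW(base.parameters(), lr=3e-4)
base(torch.randn(4, 16)).sum().backward()
base_opt.step()
ckpt = {"model": base.state_dict(), "optim": base_opt.state_dict()}

# (1) baseline: FA-2 model + AdamW, resume both states -> should track the old run
model = build_model("fa2")
model.load_state_dict(ckpt["model"])
optim = AdamW(model.parameters(), lr=3e-4)
optim.load_state_dict(ckpt["optim"])

# (2) switch only the attention kernel: FA-3 model, same weights, same AdamW
model_fa3 = build_model("fa3")
model_fa3.load_state_dict(ckpt["model"])  # if this alone breaks, suspect the model / state-dict path

# (3) switch only the optimizer: FA-2 model, 8-bit AdamW with fresh optimizer state
# from torchao.prototype.low_bit_optim import AdamW8bit  # assumed torchao module path
# optim_8bit = AdamW8bit(model.parameters(), lr=3e-4)    # if this alone diverges, suspect the optimizer path
```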

@asahni04
Author

asahni04 commented Dec 19, 2024

Hi @weifengpy, sorry for the delayed response. Yes:

  1. The baseline is an FA-2 checkpoint trained with AdamW.
  2. Switching to FA-3 directly for inference (single-GPU and multi-GPU TP-based) on the model trained with FA-2 leads to broken results. However, fine-tuning from scratch with FA-3 seems to work and gives around a 30% speedup depending on the parallelism config.
  3. With 8-bit Adam the loss seems to diverge after some iterations. I tried various block_size values and am using the torchao implementation. Any suggestions to help solve it? Could it be an error due to the TP/DP config?
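Not a fix, but one way to localize it might be a single-GPU repro before involving TP/DP: same init and data, fp32 AdamW vs torchao's AdamW8bit at a couple of block_size values, then watch where the loss curves split. A rough sketch (the torchao import path and the block_size keyword are assumptions based on the prototype low-bit optim API; verify against the installed version):

```python
# Single-GPU divergence check: identical init and data, fp32 AdamW vs torchao's
# AdamW8bit at two block_size values. The torchao import path and block_size
# kwarg are assumptions; verify against the installed version.
import copy
import torch
import torch.nn.functional as F
from torchao.prototype.low_bit_optim import AdamW8bit

torch.manual_seed(0)
base = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024)
).cuda()

runs = {
    "adamw_fp32": (copy.deepcopy(base), lambda p: torch.optim.AdamW(p, lr=3e-4)),
    "adamw8bit_bs256": (copy.deepcopy(base), lambda p: AdamW8bit(p, lr=3e-4, block_size=256)),
    "adamw8bit_bs2048": (copy.deepcopy(base), lambda p: AdamW8bit(p, lr=3e-4, block_size=2048)),
}
opts = {name: make_opt(model.parameters()) for name, (model, make_opt) in runs.items()}

for step in range(500):
    x = torch.randn(32, 1024, device="cuda")
    y = x.roll(1, dims=-1)  # arbitrary fixed target so every run sees the same task
    for name, (model, _) in runs.items():
        loss = F.mse_loss(model(x), y)
        loss.backward()
        opts[name].step()
        opts[name].zero_grad(set_to_none=True)
        if step % 100 == 0:
            print(f"step {step:4d} {name}: {loss.item():.5f}")
```

If the 8-bit runs already drift apart from the fp32 baseline here, the quantized optimizer state itself is probably the issue rather than the TP/DP setup.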

@gnadathur
Contributor

cc: @vkuzo
