
Invalid kwarg fused passed to bitsandbytes AdamW8bit #2152

Open
mlazos opened this issue Dec 12, 2024 · 7 comments
@mlazos

mlazos commented Dec 12, 2024

Hi, when running the following command:
tune run lora_finetune_single_device --config llama3/8B_lora_single_device model.lora_rank=16 optimizer=bitsandbytes.optim.AdamW8bit gradient_accumulation_steps=4 tokenizer.max_seq_len=2048 max_steps_per_epoch=100 model.lora_attn_modules="['q_proj','k_proj','v_proj','output_proj']" model.apply_lora_to_mlp=True log_peak_memory_stats=True compile=True checkpointer.checkpoint_dir=checkpoints/original tokenizer.path=checkpoints/original/tokenizer.model checkpointer.output_dir=checkpoints/original

The run fails with a stack trace.

It looks like torchtune unconditionally passes fused as a kwarg to the optimizer, even though the bitsandbytes optimizer doesn't accept that kwarg.

Related issue: #1998

Version info:
PyTorch: 1b3f8b75896720e88362cbec7db32abc52afa83e
torchtune: f2bd4bc
torchao: 039cef4ad546716aa04cd54c461feb173f7fe403
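
For reference, a minimal sketch of the kwarg mismatch outside of torchtune (the layer shape and learning rate are arbitrary, and the exact exception text is my assumption):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(8, 8, device="cuda")

# torch.optim.AdamW accepts a `fused` kwarg on CUDA...
torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

# ...but bitsandbytes.optim.AdamW8bit does not, so forwarding the same
# kwargs should raise something like:
#   TypeError: __init__() got an unexpected keyword argument 'fused'
bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, fused=True)
```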

@gau-nernst
Contributor

You probably have to manually delete the fused: True key in the config file. I don't think the torchtune CLI can remove a specific key from the config?

@mlazos
Author

mlazos commented Dec 12, 2024

Makes sense, but shouldn't you throw a nicer error in this case (maybe giving the instructions you just gave)? Or just not pass fused by default? If you're compiling the optimizer, I wouldn't expect fused to matter.

As a first-time user, seeing something like this made me think there was a bug in the framework.
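
To make that concrete, one hypothetical shape such a guard could take (nothing below is an existing torchtune API; the helper name and message are made up for illustration):

```python
import inspect

def _instantiate_optimizer(optimizer_cls, params, **kwargs):
    # Hypothetical sketch: only forward `fused` when the optimizer's
    # __init__ actually accepts it; otherwise drop it with a hint
    # instead of surfacing a raw TypeError.
    sig = inspect.signature(optimizer_cls.__init__)
    if "fused" in kwargs and "fused" not in sig.parameters:
        print(
            f"{optimizer_cls.__name__} does not accept `fused`; ignoring it. "
            "You can also remove it from the config with `~optimizer.fused`."
        )
        kwargs.pop("fused")
    return optimizer_cls(params, **kwargs)
```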

@felipemello1
Contributor

felipemello1 commented Dec 12, 2024

that's good feedback @mlazos. I will check what we can do to make it more intuitive. We add fused by default because it is faster in some scenarios.

To remove it in the CLI, you can do tune run <config> ~optimizer.fused

The '~' will delete it.

If you hit any other issues, let us know!

@felipemello1 felipemello1 added the better engineering Tasks which help improve eng productivity e.g. building tools, cleaning up code, writing docs label Dec 12, 2024
@felipemello1 felipemello1 self-assigned this Dec 12, 2024
@gau-nernst
Contributor

(I wanted to tag @felipemello1, but I saw that you have already replied, so I will skip my comment on the UX issue.)

PyTorch optimizers use the for-each implementation by default (https://pytorch.org/docs/stable/optim.html#algorithms — after the list of optimizers you can see the description of the for-each and fused implementations).

The for-each implementation actually consumes quite a lot more VRAM because it materializes some intermediate tensors (and it is slightly slower than fused). This is especially bad for big models like today's LLMs (though for this specific LoRA recipe it wouldn't matter much). I don't have the exact historical context, but I'm guessing the for-each implementation was aimed at CNNs, where you have a lot of small tensors, so batching them cuts down kernel launch overhead significantly. Still waiting for the fused implementation to become the new default 😄
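
To make the distinction concrete, a minimal sketch of selecting the two implementations on stock torch.optim.AdamW (parameter sizes and the CUDA device are just assumptions; the fused path is only available on supported devices):

```python
import torch

params = [torch.nn.Parameter(torch.randn(1024, 1024, device="cuda"))
          for _ in range(4)]

# Default today: the for-each implementation. It batches the update across
# parameters with torch._foreach_* ops, which materializes intermediate
# tensor lists and so costs extra peak memory.
opt_foreach = torch.optim.AdamW(params, lr=1e-3, foreach=True)

# Fused implementation: the update runs as fused kernels, avoiding those
# intermediates; typically faster and lighter on memory.
opt_fused = torch.optim.AdamW(params, lr=1e-3, fused=True)
```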

@mlazos
Author

mlazos commented Dec 12, 2024

@gau-nernst yeah, agreed that foreach is not as performant and consumes more memory than a fully fused kernel. The crux of my comment is that we can torch.compile the foreach optimizer and generate fused kernels automatically today, even for optimizers that don't have a fused implementation. It might be worth adding an option for that in the config as well. (I'm tooting my own horn because I added that capability to torch.compile 😉)

Other than that thanks for the quick replies!

@felipemello1
Contributor

felipemello1 commented Dec 21, 2024

@mlazos , n00b question: how can we do it? Would torch.compile(optimizer) work?

@mlazos
Author

mlazos commented Dec 21, 2024

> @mlazos , n00b question: how can we do it? Would torch.compile(optimizer) work?

Since the optimizer isn't a module, this doesn't work; you compile the step instead, i.e. torch.compile(optim.step).

Now that you mention it, I could add support for compiling the optimizer object directly, but for now just compile the step.
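
For anyone following along, a minimal sketch of that pattern (the toy model, shapes, and CUDA device are just for illustration; an optimizer without a native fused path, e.g. bitsandbytes.optim.AdamW8bit, would slot in the same way):

```python
import torch

model = torch.nn.Linear(64, 64, device="cuda")
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)  # no `fused` needed

# The optimizer is not an nn.Module, so compile its step, not the object itself.
opt.step = torch.compile(opt.step)

x = torch.randn(32, 64, device="cuda")
loss = model(x).sum()
loss.backward()
opt.step()       # first call compiles; later calls reuse the generated kernels
opt.zero_grad()
```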
