[ROCm] Unable to Run FPX Weights #967
Comments
FPx quantization is backed by a custom CUDA kernel, so it is not available on ROCm: https://github.com/pytorch/ao/tree/main/torchao/csrc/cuda/fp6_llm. It's strange that it runs with bfloat16, though; perhaps it is slow precisely because it doesn't use the CUDA kernel. I don't know ROCm well enough, but maybe it's not so hard to port it to ROCm.
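For context, this is roughly how the FPx path gets exercised from Python; a minimal sketch assuming the `quantize_`/`fpx_weight_only` API exported by `torchao.quantization` (the model, sizes, and ROCm probe here are illustrative, not the reporter's script):

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, fpx_weight_only

# torch.version.hip is a string on ROCm builds and None on CUDA builds.
print("ROCm build:", torch.version.hip is not None)

# The floatx README recommends float16 as the base dtype.
model = nn.Sequential(nn.Linear(1024, 1024)).half().to("cuda")

# fp6 with 3 exponent bits and 2 mantissa bits; this is the config that
# dispatches to the custom fp6_llm kernel at inference time.
quantize_(model, fpx_weight_only(3, 2))

x = torch.randn(1, 1024, dtype=torch.float16, device="cuda")
with torch.no_grad():
    print(model(x).shape)
```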
It actually compiles something when I install from source; I see 5 threads light up. I thought torch used the hipify script for C extensions to try to auto-convert the code? Usually if something isn't supported by ROCm it gets caught when the wheel builds. Additionally, the error is different between the source-compiled install and the pip wheel. I can fetch the pip version later, but it's a lot more boring, essentially just saying that the function doesn't exist.
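For reference, that auto-conversion lives in `torch.utils.cpp_extension`: when the installed PyTorch is a ROCm build, `CUDAExtension` runs hipify over the listed CUDA sources before compiling them. A minimal sketch (extension name and source files are placeholders, not the actual torchao build):

```python
# setup.py -- illustrative only; names and paths are placeholders
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_fpx_ext",
    ext_modules=[
        # On a ROCm build of PyTorch, CUDAExtension hipifies the .cu sources
        # automatically before handing them to the HIP compiler, which is why
        # a source install can still "compile something" on ROCm.
        CUDAExtension(
            name="my_fpx_ext._C",
            sources=["csrc/fpx_linear.cpp", "csrc/fpx_linear_kernel.cu"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```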
Interesting. I don't know much about how PyTorch handles building for ROCm. Can you run this script? https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_fp6.py It will help verify whether you can run the FPx kernel correctly.
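Before running the full benchmark, a quick probe can show whether the compiled extension actually registered the custom op; the op name below is an assumption based on the fp6_llm sources and may differ between torchao releases:

```python
import torch
import torchao  # importing torchao loads the compiled extension and registers its custom ops

print("torch:", torch.__version__, "| ROCm build:", torch.version.hip is not None)

# Assumed op name for the fp6_llm kernel; adjust if your release uses a different one.
print("FPx kernel registered:", hasattr(torch.ops.torchao, "quant_llm_linear"))
```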
Same exact traceback as my original post. The one example I know of that works on both ROCm and CUDA is exllama. It uses torch cpp_extensions in
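If the reference is to JIT-compiled extensions, those go through `torch.utils.cpp_extension.load`, which applies the same hipify pass when the installed PyTorch is a ROCm build; a hypothetical sketch with placeholder file names:

```python
from torch.utils.cpp_extension import load

# Illustrative JIT build; the source files are placeholders. On ROCm, load()
# hipifies the CUDA sources before compiling, so the same extension code can
# serve both CUDA and ROCm installs.
ext = load(
    name="my_jit_ext",
    sources=["csrc/ops.cpp", "csrc/ops_kernel.cu"],
    verbose=True,
)
```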
Compiling ao from source using `pip install git+https://github.com/pytorch/ao.git` results in a very fun throw when running FPX weights using the script below.

Setup is 1x 7900XTX on torch 2.5+rocm6.2. All other quantizations work just fine, with the exception of `float8_dynamic_activation_float8_weight`, because gfx11 currently does not implement torch's `_scaled_mm()` function. Using `bfloat16` as the base dtype instead actually does run, but it's wicked slow from the conversions. The floatx README states to use `float16`, so I assume that's the correct way.

Python traceback: traceback.txt
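As a side note on the `_scaled_mm()` gap mentioned above, a small probe like the one below (illustrative only; the `torch._scaled_mm` signature shown matches recent 2.x releases and has changed over time) shows whether the fp8 matmul primitive that `float8_dynamic_activation_float8_weight` relies on is usable on a given GPU:

```python
import torch

# Illustrative probe; dims must be multiples of 16 and the second operand
# must be column-major. On gfx11 this currently fails.
try:
    a = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fn).t()
    scale = torch.tensor(1.0, device="cuda")
    out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
    print("_scaled_mm works:", out.shape)
except Exception as exc:
    print("fp8 path not supported on this GPU:", exc)
```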