FP8 Triton matmul code silently requires contiguous tensors #2713
CC @choutim
Hello @rationalism, thank you for your questions. triton-lang/triton#3952 and pytorch/pytorch#125437 should be related.
I think this issue should be resolved in #2919. The quantization kernel in Triton was writing its output using the same strides as the input while returning a contiguous tensor, which effectively transposed the output. After the fix, it should always return a contiguous output in the proper layout.
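To make the failure mode concrete, here is a minimal, self-contained PyTorch sketch (no Triton required) of what goes wrong when a contiguous output buffer is written using the input's strides. The variable names are illustrative and are not taken from the FBGEMM kernels:

```python
import torch

M, N = 4, 3
y = torch.arange(N * M, dtype=torch.float32).reshape(N, M)
x = y.t()                  # shape (M, N), non-contiguous view with strides (1, 4)
out = torch.empty(M, N)    # output allocated contiguous, strides (3, 1)

# Buggy behaviour: element (i, j) is stored at offset i*stride_in0 + j*stride_in1,
# i.e. the *input's* strides are reused to index the *output* buffer.
flat = out.view(-1)
for i in range(M):
    for j in range(N):
        flat[i * x.stride(0) + j * x.stride(1)] = x[i, j]

print(torch.equal(out, x))                 # False: the data landed transposed
print(torch.equal(out.view(N, M).t(), x))  # True: reading as (N, M) and transposing recovers x

# The fix, in spirit: index the output with its own strides.
for i in range(M):
    for j in range(N):
        flat[i * out.stride(0) + j * out.stride(1)] = x[i, j]

print(torch.equal(out, x))                 # True
```

The same stride mismatch inside the quantize kernel is why non-contiguous inputs ran fine but produced silently transposed FP8 outputs.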
Summary: Pull Request resolved: #2919. This diff fixes an issue where our Triton FP8 quantize functions didn't properly handle non-contiguous inputs. Specifically, they wrote to the output tensor using the same strides as the input, even though the output is always allocated as contiguous. This resulted in the output being unintentionally transposed in some cases: non-contiguous inputs would run fine but produce silently transposed outputs. It was noted on GitHub here: #2713. Adding explicit output strides to the kernel resolves the issue. I also found a small issue with D59248142 where scaling wouldn't be applied when the number of elements was smaller than the block size; this caused fp8_gemm_test to fail. I resolved it by extending the check for when to scale. Reviewed By: jianyuh. Differential Revision: D60535956. fbshipit-source-id: 0c449e921e2703f2275e24028238f83fec1c0427
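For illustration, the pattern described in the summary ("adding explicit output strides to the kernel") looks roughly like the following. This is a sketch, not FBGEMM's actual quantize kernel: it is just a 2-D copy kernel whose store uses dedicated `stride_out_*` arguments instead of reusing the input strides, and all kernel and argument names are made up.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_2d_kernel(
    in_ptr, out_ptr,
    M, N,
    stride_in_m, stride_in_n,
    stride_out_m, stride_out_n,  # explicit output strides: the essence of the fix
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Load with the input's strides (works for transposed / non-contiguous views).
    x = tl.load(in_ptr + offs_m[:, None] * stride_in_m + offs_n[None, :] * stride_in_n,
                mask=mask, other=0.0)
    # Store with the *output's* strides, not the input's, so a contiguous
    # output buffer is filled in its own layout.
    tl.store(out_ptr + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n,
             x, mask=mask)


# Usage (requires a GPU with Triton installed):
a = torch.randn(48, 64, device="cuda")
x = a.t()                                  # (64, 48), non-contiguous
out = torch.empty(64, 48, device="cuda")   # contiguous
grid = (triton.cdiv(64, 32), triton.cdiv(48, 32))
copy_2d_kernel[grid](
    x, out, 64, 48,
    x.stride(0), x.stride(1),
    out.stride(0), out.stride(1),
    BLOCK_M=32, BLOCK_N=32,
)
assert torch.equal(out, x)
```

Before a fix of this kind, the store would use `stride_in_m` / `stride_in_n`, which is exactly the pattern that produced transposed outputs for non-contiguous inputs.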
Hello! Thank you very much for this FP8 rowwise matmul code, it's been extremely helpful. However, there is a subtle bug / hidden requirement when calling, e.g., this code here:
FBGEMM/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py, line 97 (commit 735f27b)
This works great, but only if the second matrix is contiguous in transposed format (e.g. for M, N, K equal to (4096, 2048, 1024), the second matrix must be contiguous in the shape (2048, 1024)). If it's not contiguous, the matmul will finish, but the results will be numerically nonsensical.
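A simple caller-side workaround for this requirement is to force the second operand into the expected row-major (N, K) layout before calling the matmul. The snippet below is a sketch: `ensure_row_major` is a hypothetical helper rather than an FBGEMM API, and the actual matmul call is omitted rather than guessed at.

```python
import torch


def ensure_row_major(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `t` unchanged if it is already contiguous,
    otherwise return a contiguous (row-major) copy. A defensive guard for
    kernels that silently assume contiguous inputs."""
    return t if t.is_contiguous() else t.contiguous()


# For M, N, K = 4096, 2048, 1024 the second operand is expected to be a
# contiguous (N, K) = (2048, 1024) tensor. A transposed view such as `w.t()`
# of a (K, N) weight is NOT contiguous and, per this issue, produced
# numerically nonsensical results when passed through unchanged.
M, N, K = 4096, 2048, 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(M, K, device=device)
w = torch.randn(K, N, device=device)

b = ensure_row_major(w.t())  # (N, K), now guaranteed contiguous
assert b.shape == (N, K) and b.is_contiguous()
# `a` and `b` can then be handed to the FP8 rowwise matmul path.
```

The cost is an explicit copy of the weight; per the maintainers above, #2919 addresses the silent-wrong-answer behavior, but an `is_contiguous()` check at the call site remains a cheap sanity guard.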