Pallas Kernel Expected Output Shape Error Using Grids #8469
Comments
For a Pallas bug you will have better luck opening an issue in the JAX repo, since they own Pallas.
+1. I see you are using the 2.4 wheel. Do you observe the same error pattern with torch_xla nightly (or 2.5)? More recent torch_xla wheels use more recent libtpu and JAX dependencies.
I will try today with 2.5 :) Update: tried with torch 2.5.1, and I still only see this bug while working with torch_xla.
cc @vanbasten23 to review the code
Hey @dshalem, I took a look. It fails to run your kernel using JAX; the error is:

You may want to make sure the kernel runs fine using JAX first. Moreover, in your original script, you need to do:

before importing JAX. Otherwise, your job may hang even if your kernel runs fine. Please find the example at https://github.com/pytorch/xla/blob/7dd2697af2fea3b12e8484d2f66b72b008a8cc3d/docs/source/features/pallas.md#custom-kernels-via-pallas.
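The pattern in the linked doc for exposing a JAX Pallas kernel to torch_xla looks roughly like the sketch below (illustrative, not verbatim from the doc; `add_vectors` and the shapes are made up, and the `make_kernel_from_pallas` usage assumes the torch_xla 2.5-era API):

```python
import jax
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Element-wise add; with no grid, each ref holds the whole array.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

# Wrap the JAX kernel so it can be called with torch tensors on an XLA device.
import torch
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
x = torch.arange(8, dtype=torch.float32).to("xla")
y = torch.arange(8, dtype=torch.float32).to("xla")
out = pt_kernel(x, y)  # executes the Pallas kernel on the TPU
```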
🐛 Bug
I am trying to write a custom Pallas kernel to run on TPU. I use blocking (a grid with a block size smaller than the input) to keep the kernel from going OOM. However, when I use grids, I get kernel errors about the expected output and input shapes: the chunking/splitting of the input does not behave as expected. I have checked that my code passes the right shapes, grid, and index map, yet the kernel itself receives wrongly shaped input.
I think it may be a bug in how the TPU backend handles the chunking in Pallas kernels, but I am not sure. Any help would be appreciated!
To Reproduce
I am attaching tests for replication (see the attachment under Additional context). Note that only the tests whose input tensors are larger than the block size fail.
My Kernel Code
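(The actual kernel is in the attached tests; the sketch below is a hypothetical minimal reconstruction based only on the shapes printed in the debug output further down — a 192-element float32 input, block shape (128,), and a grid of 2 — not the author's code.)

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Each grid step sees one (128,)-shaped block of the input. The last
    # block of a 192-element array is padded, since 2 * 128 = 256 > 192.
    o_ref[...] = x_ref[...] + 1.0

def add_one(x):
    n = x.shape[0]                # 192 in the failing tests
    block = 128
    grid = pl.cdiv(n, block)      # ceil(192 / 128) = 2 grid steps
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct((n,), x.dtype),
        grid=(grid,),
        in_specs=[pl.BlockSpec((block,), lambda i: (i,))],
        out_specs=pl.BlockSpec((block,), lambda i: (i,)),
    )(x)

x = jnp.arange(192, dtype=jnp.float32)
# On the affected torch_xla/libtpu builds, this pattern reportedly fails
# with the Mosaic layout error quoted below.
print(add_one(x))
```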
Debug Output And Stack Trace
RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 1: XLA layout does not match MLIR layout for an operand of shape f32[192]: expected {0:T(128)}, got {0:T(256)}
But printing the values I provide to the pallas_call shows they are what I intended:

Out shape: ShapeDtypeStruct(shape=(192,), dtype=float32), grid: 2, block_shape: (128,)

(Note that 2 grid steps × 128-element blocks = 256 elements, which appears to match the unexpected {0:T(256)} layout in the error, even though the actual array has only 192 elements.)
Expected behavior
The tests should pass; when run on GPU, they all do.
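One way to check the kernel logic independently of the Mosaic TPU compiler is Pallas interpret mode. A sketch, reusing the hypothetical `add_one_kernel` from above (`interpret=True` runs the kernel in the Pallas interpreter instead of compiling it with Mosaic):

```python
out = pl.pallas_call(
    add_one_kernel,                           # hypothetical kernel from the sketch above
    out_shape=jax.ShapeDtypeStruct((192,), jnp.float32),
    grid=(2,),
    in_specs=[pl.BlockSpec((128,), lambda i: (i,))],
    out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    interpret=True,                           # bypass Mosaic to isolate compiler-side issues
)(jnp.arange(192, dtype=jnp.float32))
```

If this succeeds while the compiled version fails, the problem is more likely in the Mosaic/libtpu layer than in the kernel itself.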
Environment
From the discussion above: the torch_xla 2.4 wheel on TPU; the reporter later saw the same error with torch 2.5.1.
Additional context
The tests for easy replication:
test_pallas_tiling.py.zip