Pallas Kernel Expected Output Shape Error Using Grids #8469

Open
dshalem opened this issue Dec 8, 2024 · 6 comments
dshalem commented Dec 8, 2024

🐛 Bug

I am trying to write a custom Pallas kernel to run on TPU. I am using a blocking approach to keep my kernel from going OOM. However, when I use grids, I get problems with the expected input and output shapes: the chunking / splitting of the input does not behave as expected. I checked that my code has the right shapes, grid, and indexing method, yet the kernel itself receives the wrong input.

I think it may be a bug in how the TPU backend handles the chunking in Pallas kernels, but I am not sure. Any help would be appreciated!

To Reproduce

I am attaching tests for replication here. You can see that only the tests whose original input tensors are larger than the block size fail.

My Kernel Code

  @jax.jit
  def round_down_and_up(x: jax.Array) -> (jax.Array, jax.Array):
      """
      Simplified wrapper for kernel execution, treating x as a vector.
      Handles explicit padding and alignment with TPU tiling constraints.
      """
      block_size = 128
      # padded_length = (original_length + block_size - 1) // block_size * block_size
      #
      # # Explicitly pad the input tensor
      # if original_length != padded_length:
      #     x = jnp.pad(x, (0, padded_length - original_length), mode="constant", constant_values=0)
  
      # Define block shape and grid
      block_shape = (128,)  # TPU requires blocks divisible by 128 for f32
      grid = (len(x) + block_size - 1) // block_shape[0]
  
      # Define BlockSpec
      block_spec = pl.BlockSpec(block_shape=block_shape, index_map=lambda i: (i,))
  
      # Define output shape
      out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  
      # # Debugging: Verify padded shape, block shape, and grid
      # jax.debug.print("Input Length: {input_length}, Padded Length: {padded_length}, "
      #                 "BlockShape: {block_shape}, Grid: {grid}",
      #                 input_length=original_length, padded_length=padded_length,
      #                 block_shape=block_shape, grid=grid)
  
      print(f"Out shape: {out_shape}, grid: {grid}, block_shape: {block_shape}")
      # Call the kernel
      x_low, x_high = pl.pallas_call(
          round_down_and_up_bfloat16_kernel,
          out_shape=(out_shape, out_shape),
          grid=(grid,),
          in_specs=(block_spec,),  # Input tiling specification
          out_specs=(block_spec, block_spec),  # Output tiling specification
      )(x)
  
      return x_low, x_high
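
If the commented-out padding path were re-enabled, it would look roughly like the sketch below (reconstructed from the commented lines above; pad_to_block is just an illustrative name and returning original_length for later truncation is an addition, not part of the failing run):

import jax
import jax.numpy as jnp

def pad_to_block(x: jax.Array, block_size: int = 128):
    # Round the length up to the next multiple of block_size
    # (e.g. 192 -> 256) so every grid step sees a full block.
    original_length = x.shape[0]
    padded_length = (original_length + block_size - 1) // block_size * block_size
    if original_length != padded_length:
        x = jnp.pad(x, (0, padded_length - original_length),
                    mode="constant", constant_values=0)
    return x, original_length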

Debug Output And Stack Trace

RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 1: XLA layout does not match MLIR layout for an operand of shape f32[192]: expected {0:T(128)}, got {0:T(256)}

But printing the values I pass to pallas_call gives:
Out shape: ShapeDtypeStruct(shape=(192,), dtype=float32), grid: 2, block_shape: (128,)

Expected behavior

The tests should not fail. When run on GPU they all pass.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.4

Additional context

The tests for easy replication:
test_pallas_tiling.py.zip

JackCaoG (Collaborator) commented Dec 9, 2024

For Pallas bugs you will have better luck opening an issue in the JAX repo, since they own Pallas.

miladm (Collaborator) commented Dec 9, 2024

+1

I see you are using the 2.4 whl. Do you observe the same error pattern with torch_xla nightly (or 2.5)? More recent torch_xla whls use more recent libtpu and JAX dependencies.

dshalem (Author) commented Dec 10, 2024

> +1
>
> I see you are using the 2.4 whl. Do you observe the same error pattern with torch_xla nightly (or 2.5)? More recent torch_xla whls use more recent libtpu and JAX dependencies.

I will try today with 2.5 :)

Update:
Not working with torch_xla 2.5 either.

torch 2.5.1
torch-xla 2.5.1

dshalem (Author) commented Dec 10, 2024

> For Pallas bugs you will have better luck opening an issue in the JAX repo, since they own Pallas.

But I only see this bug while working with torch_xla

miladm (Collaborator) commented Dec 10, 2024

cc @vanbasten23 to review the code

vanbasten23 (Collaborator) commented:

Hey @dshalem, I took a look. Your kernel also fails to run with plain JAX:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


@jax.jit
def round_down_and_up(x: jax.Array) -> (jax.Array, jax.Array):
    """
    Simplified wrapper for kernel execution, treating x as a vector.
    Handles explicit padding and alignment with TPU tiling constraints.
    """
    block_size = 128
    # padded_length = (original_length + block_size - 1) // block_size * block_size
    #
    # # Explicitly pad the input tensor
    # if original_length != padded_length:
    #     x = jnp.pad(x, (0, padded_length - original_length), mode="constant", constant_values=0)

    # Define block shape and grid
    block_shape = (128,)  # TPU requires blocks divisible by 128 for f32
    grid = (len(x) + block_size - 1) // block_shape[0]

    # Define BlockSpec
    block_spec = pl.BlockSpec(block_shape=block_shape, index_map=lambda i: (i,))

    # Define output shape
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)

    # # Debugging: Verify padded shape, block shape, and grid
    # jax.debug.print("Input Length: {input_length}, Padded Length: {padded_length}, "
    #                 "BlockShape: {block_shape}, Grid: {grid}",
    #                 input_length=original_length, padded_length=padded_length,
    #                 block_shape=block_shape, grid=grid)

    print(f"Out shape: {out_shape}, grid: {grid}, block_shape: {block_shape}")
    # Call the kernel
    x_low, x_high = pl.pallas_call(
        round_down_and_up_bfloat16_kernel,
        out_shape=(out_shape, out_shape),
        grid=(grid,),
        in_specs=(block_spec,),  # Input tiling specification
        out_specs=(block_spec, block_spec),  # Output tiling specification
    )(x)

    return x_low, x_high

def round_down_and_up_bfloat16_kernel(x_ref, o_ref_down, o_ref_up):
    # # Define constants for bitwise operations
    mask = jnp.array(0xFFFF0000, dtype=jnp.uint32).astype(jnp.int32)
    increment = jnp.array(0x10000, dtype=jnp.uint32).astype(jnp.int32)

    # Perform operations on the current tile
    x_bits = x_ref[...].view(jnp.int32)  # Treat float32 as int32
    bf16_towards_zero = x_bits & mask  # Rounded down
    bf16_next = bf16_towards_zero + increment  # Rounded up in one step

    # Write results back to output references
    o_ref_down[...] = bf16_towards_zero.view(jnp.float32)  # Convert back to float32
    o_ref_up[...] = bf16_next.view(jnp.float32)  # Convert back to float32


if __name__ == "__main__":
    #test_pt_wrapper_large_input()
    # test_pt_wrapper_basic_functionality()
    blk=128
    x = jnp.arange(blk)
    out = round_down_and_up(x)
    print(out)

The error is

$ python test_pallas_tiling_copy.py 
Out shape: ShapeDtypeStruct(shape=(128,), dtype=int32), grid: 1, block_shape: (128,)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/disks/persist/test_pallas_tiling_copy.py", line 69, in <module>
    out = round_down_and_up(x)
  File "/mnt/disks/persist/test_pallas_tiling_copy.py", line 38, in round_down_and_up
    x_low, x_high = pl.pallas_call(
  File "/home/xiowei/anaconda3/envs/myvllmenv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1844, in wrapped
    jaxpr, consts = _trace_kernel_to_jaxpr(
  File "/home/xiowei/anaconda3/envs/myvllmenv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1416, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
  File "/home/xiowei/anaconda3/envs/myvllmenv/lib/python3.10/site-packages/jax/_src/pallas/primitives.py", line 833, in wrap_with_transforms
    return f(*new_args)
  File "/mnt/disks/persist/test_pallas_tiling_copy.py", line 60, in round_down_and_up_bfloat16_kernel
    o_ref_down[...] = bf16_towards_zero.view(jnp.float32)  # Convert back to float32
  File "/home/xiowei/anaconda3/envs/myvllmenv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 1049, in op
    return getattr(self.aval, f"_{name}")(self, *args)
ValueError: Invalid dtype for `swap`. Ref dtype: int32. Value dtype: float32. 

You may want to make sure the kernel runs fine using JAX.
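
For what it's worth, the ValueError above comes from jnp.arange(blk) defaulting to int32 while the kernel writes float32 views into the output refs; one way to get past that particular error, assuming a float32 input is what was intended, is to build the test input as float32:

blk = 128
# jnp.arange defaults to int32; the kernel views its refs as float32,
# so create the input as float32 to keep the ref/value dtypes consistent.
x = jnp.arange(blk, dtype=jnp.float32)
out = round_down_and_up(x)
print(out)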

Moreover, in your original script, you need to do:

from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()

before importing JAX. Otherwise, your job may hang even if your kernel runs fine. Please find the example at https://github.com/pytorch/xla/blob/7dd2697af2fea3b12e8484d2f66b72b008a8cc3d/docs/source/features/pallas.md#custom-kernels-via-pallas.
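
Roughly, the import ordering looks like this (a sketch following the doc linked above):

# Initialize the TPU runtime through torch_xla before any JAX import;
# importing JAX first can cause the process to hang.
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl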
