reshape operation with Pallas on TPU/Mosaic cause XlaRuntimeError: INTERNAL #21999

Closed
Lime-Cakes opened this issue Jun 20, 2024 · 2 comments
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

Comments

@Lime-Cakes

Description

Reshaping [1,11,11,1,128] to [121,128] fails inside a Pallas kernel. Reshaping [11,11,128] to [121,128] instead still fails with the same error (shown below). Both cases were tested on 0.4.28. Testing on 0.4.30 produced a different error.

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast

The MLIR operation involved:
  %13 = "vector.shape_cast"(%11) : (vector<1x11x11x1x128xf32>) -> vector<121x128xf32>
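
For reference, a quick sanity check on the two shapes in the `vector.shape_cast` op above. This is only a sketch in plain Python (no JAX required); the variable names are illustrative and not taken from the kernel. It shows that the reshape preserves the element count and the minor dimension, so the failure is not a size mismatch in the reshape itself.

```python
from math import prod

# Shapes copied from the vector.shape_cast op in the error message.
src_shape = (1, 11, 11, 1, 128)  # operand type
dst_shape = (121, 128)           # result type

src_elems = prod(src_shape)
dst_elems = prod(dst_shape)
print(src_elems, dst_elems)  # prints: 15488 15488

# The reshape is element-preserving and leaves the last dimension untouched.
same_size = src_elems == dst_elems
minor_dim_unchanged = src_shape[-1] == dst_shape[-1]
```

Since both checks pass, the "unsupported shape cast" message appears to concern which shape casts Mosaic can lower, not the validity of the reshape.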

Full error.

  File "/kaggle/working/basic_dit.py", line 91, in neigh_attn_pall_2d_batched_single_head
    return pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 181, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 2789, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1525, in _pjit_call_impl
    return xc._xla.pjit(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1508, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1438, in _pjit_call_impl_python
    inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2363, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2860, in from_hlo
    xla_executable = _cached_compilation(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2678, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 330, in compile_or_get_cached
    return _compile_and_write_cache(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 501, in _compile_and_write_cache
    executable = backend_compile(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 237, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast

The MLIR operation involved:
  %13 = "vector.shape_cast"(%11) : (vector<1x11x11x1x128xf32>) -> vector<121x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

Error on 0.4.30
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Mosaic kernel operand 5 has an unexpected layout: expected {4,3,2,1,0:T(2,128)}, got {4,3,2,1,0:T(8,128)}
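
To make the two layouts in this message easier to compare, here is a minimal parsing sketch. It assumes the usual reading of XLA layout strings, where the list before the colon is the minor-to-major dimension order and `T(a,b)` is the tile shape; the thread itself does not spell this out. `parse_layout` and `LAYOUT_RE` are hypothetical helpers, not part of JAX.

```python
import re

# Matches layout strings such as "{4,3,2,1,0:T(8,128)}".
LAYOUT_RE = re.compile(r"\{(?P<order>[\d,]+):T\((?P<tile>[\d,]+)\)\}")

def parse_layout(text):
    """Return a (dimension_order, tile_shape) pair for each layout string in text."""
    return [
        (
            tuple(int(d) for d in m["order"].split(",")),
            tuple(int(d) for d in m["tile"].split(",")),
        )
        for m in LAYOUT_RE.finditer(text)
    ]

message = (
    "Mosaic kernel operand 5 has an unexpected layout: "
    "expected {4,3,2,1,0:T(2,128)}, got {4,3,2,1,0:T(8,128)}"
)
expected, got = parse_layout(message)
print("expected tile:", expected[1], "got tile:", got[1])
```

Under that reading, the dimension order is identical in both layouts and only the tile shape differs: (2, 128) expected versus (8, 128) received.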

Full error on 0.4.30.

  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 2834, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1635, in _pjit_call_impl
    return xc._xla.pjit(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1614, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1544, in _pjit_call_impl_python
    ).compile(compile_options)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2496, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2995, in from_hlo
    xla_executable = _cached_compilation(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2810, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 378, in compile_or_get_cached
    return _compile_and_write_cache(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 608, in _compile_and_write_cache
    executable = backend_compile(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 238, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Mosaic kernel operand 5 has an unexpected layout: expected {4,3,2,1,0:T(2,128)}, got {4,3,2,1,0:T(8,128)}

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.13 (main, Mar 12 2024, 12:16:25) [GCC 12.2.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='d0909048617f', release='6.1.42+', version='#1 SMP PREEMPT_DYNAMIC Sun Oct  8 14:23:56 UTC 2023', machine='x86_64')

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.13 (main, Mar 12 2024, 12:16:25) [GCC 12.2.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='d0909048617f', release='6.1.42+', version='#1 SMP PREEMPT_DYNAMIC Sun Oct  8 14:23:56 UTC 2023', machine='x86_64')

Lime-Cakes added the bug label on Jun 20, 2024
justinjfu added the pallas label on Jun 20, 2024

@justinjfu
Collaborator

#22015 may address this issue: it lets you manually request a layout change for an array. Otherwise, you may have to set your block size to (8, 128), assuming it is currently some other shape such as (2, 128).

@Lime-Cakes
Author

I realized I've misunderstood the docs. Thanks! Just to be sure: this means the window size must be (..., 8, 128), or the last two dimensions must be multiples of 8 and 128 respectively. Other shapes are not supported.
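
The rule as stated in this comment can be written down as a small check. This is a sketch of the comment's reading only, which no maintainer confirms in the thread. The constants 8 and 128 come from the "(8, 128)" mentioned above, and `is_tile_aligned` and `round_up_block` are hypothetical helpers, not Pallas APIs. Rounding a block shape up (and padding the data to match) is one possible way to satisfy the rule, not something the thread prescribes.

```python
SUBLANES, LANES = 8, 128  # tile size from the thread; assumed, not queried from JAX

def is_tile_aligned(block_shape):
    """True if the last two dims are multiples of (8, 128), per the comment's rule."""
    if len(block_shape) < 2:
        return False
    rows, cols = block_shape[-2:]
    return rows % SUBLANES == 0 and cols % LANES == 0

def round_up_block(block_shape):
    """Round the last two dims up to the next multiples of (8, 128)."""
    *lead, rows, cols = block_shape
    rows = -(-rows // SUBLANES) * SUBLANES  # ceiling division, then scale back
    cols = -(-cols // LANES) * LANES
    return (*lead, rows, cols)

for shape in [(121, 128), (2, 128), (8, 128), (1, 11, 11, 1, 128)]:
    print(shape, is_tile_aligned(shape), round_up_block(shape))
```

By this check, neither the (121, 128) reshape target nor a (2, 128) block is aligned, while (8, 128) is. The nearest aligned shape above (121, 128) is (128, 128).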
