You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Reshaping an array of shape [1, 11, 11, 1, 128] to [121, 128] fails inside a Pallas kernel. I also tried reshaping [11, 11, 128] to [121, 128] instead, and it fails with the same error (shown below). This was on JAX 0.4.28; testing on 0.4.30 produced a different error.
File "/kaggle/working/basic_dit.py", line 91, in neigh_attn_pall_2d_batched_single_head
return pl.pallas_call(
File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 304, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 181, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 2789, in bind
return self.bind_with_trace(top_trace, args, params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1525, in _pjit_call_impl
return xc._xla.pjit(
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1508, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1438, in _pjit_call_impl_python
inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2363, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2860, in from_hlo
xla_executable = _cached_compilation(
File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2678, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 330, in compile_or_get_cached
return _compile_and_write_cache(
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 501, in _compile_and_write_cache
executable = backend_compile(
File "/usr/local/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 237, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast
The MLIR operation involved:
%13 = "vector.shape_cast"(%11) : (vector<1x11x11x1x128xf32>) -> vector<121x128xf32>
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
Error on 0.4.30:
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Mosaic kernel operand 5 has an unexpected layout: expected {4,3,2,1,0:T(2,128)}, got {4,3,2,1,0:T(8,128)}
Full error on 0.4.30.
File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 2834, in bind
return self.bind_with_trace(top_trace, args, params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1635, in _pjit_call_impl
return xc._xla.pjit(
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1614, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1544, in _pjit_call_impl_python
).compile(compile_options)
File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2496, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2995, in from_hlo
xla_executable = _cached_compilation(
File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2810, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 378, in compile_or_get_cached
return _compile_and_write_cache(
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 608, in _compile_and_write_cache
executable = backend_compile(
File "/usr/local/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 238, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Mosaic kernel operand 5 has an unexpected layout: expected {4,3,2,1,0:T(2,128)}, got {4,3,2,1,0:T(8,128)}
System info (python version, jaxlib version, accelerator, etc.)
#22015 may address this issue (it allows you to manually request a layout change for an array). Alternatively, you may need to set your block size to (8, 128), assuming it is currently some other shape such as (2, 128).
I realize I misunderstood the docs. Thanks! Just to be sure: this means the window (block) shape must be (..., 8, 128), or its last two dimensions must be multiples of 8 and 128 respectively. Other shapes are not supported.
Description
Reshaping an array of shape [1, 11, 11, 1, 128] to [121, 128] fails inside a Pallas kernel. I also tried reshaping [11, 11, 128] to [121, 128] instead, and it fails with the same error (shown below). This was on JAX 0.4.28; testing on 0.4.30 produced a different error.
Full error.
Error on 0.4.30
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Mosaic kernel operand 5 has an unexpected layout: expected {4,3,2,1,0:T(2,128)}, got {4,3,2,1,0:T(8,128)}
Full error on 0.4.30.
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: