Description

Running the reproduction script results in the following output and traceback:
((ShapedArray(float32[1,1]), ShapedArray(int32[10,1,0])), frozenset())
BCOO(float32[10], nse=1, n_batch=1)
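The second line of output is the repr of a BCOO array with shape (10,), one batch dimension and one stored element per batch entry. As a minimal sketch (assuming jax.experimental.sparse as shipped with the reported JAX 0.4.28; the input vector is illustrative, not from the report), the snippet below builds a BCOO with the same layout and prints its buffer shapes. Its data buffer has shape (10, 1) and its indices buffer has shape (10, 1, 0), which are the tensor types returned by the failing func.return operation. The declared function result type in the verifier error, tensor<1x1xf32>, instead matches the float32[1,1] data aval on the first line of output.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative input (not from the report): a dense length-10 vector.
x = jnp.arange(1.0, 11.0, dtype=jnp.float32)

# One batch dimension covering all 10 entries, no sparse dimensions left,
# and one stored element (nse=1) per batch entry.
m = sparse.BCOO.fromdense(x, nse=1, n_batch=1)

print(m)                # shape, nse and n_batch summary
print(m.data.shape)     # batch dims + (nse,)           -> (10, 1)
print(m.indices.shape)  # batch dims + (nse, n_sparse)  -> (10, 1, 0)
print(m.todense())      # round-trips back to x
```

Because n_batch=1 consumes the only axis, n_sparse is 0, which is why the trailing dimension of the indices buffer is 0.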
Traceback (most recent call last):
File "/nix/store/g2zlb8pq0rd2kyz7q6v1l2ivz96am8lp-env/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 994, in lower_jaxpr_to_module
if not ctx.module.operation.verify():
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: unknown: type of return operand 0 ('tensor<10x1xf32>') doesn't match function result type ('tensor<1x1xf32>') in function @main
note: unknown: see current operation: "func.return"(%104, %106) : (tensor<10x1xf32>, tensor<10x1x0xi32>) -> ()
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "***********************************.py", line 42, in <module>
print(jax.jit(sp.sparsify(inner))(lhs, rhs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Cannot lower jaxpr with verifier errors:
type of return operand 0 ('tensor<10x1xf32>') doesn't match function result type ('tensor<1x1xf32>') in function @main
at loc(unknown)
see current operation: "func.return"(%104, %106) : (tensor<10x1xf32>, tensor<10x1x0xi32>) -> ()
at loc(unknown)
Define JAX_DUMP_IR_TO to dump the module.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
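The traceback points to two environment variables for digging deeper: JAX_TRACEBACK_FILTERING=off keeps JAX's internal frames, and JAX_DUMP_IR_TO names a directory to dump the lowered module to. A minimal sketch of setting both from Python, assuming they are set before jax is first imported (JAX reads many of its configuration defaults from the environment at import time); the dump directory path is hypothetical:

```python
import os
import tempfile

# Hypothetical dump location; any writable directory works.
dump_dir = os.path.join(tempfile.gettempdir(), "jax_ir_dump")
os.makedirs(dump_dir, exist_ok=True)

# Set before importing jax so the values are picked up at import time.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"  # keep JAX-internal frames in tracebacks
os.environ["JAX_DUMP_IR_TO"] = dump_dir        # where lowered modules get written

import jax  # noqa: E402  (intentionally imported after configuring the environment)

# ... re-run the failing jax.jit(sp.sparsify(inner))(lhs, rhs) call here ...
```

Exporting the same two variables in the shell before launching the script has the same effect.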
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.12.2 (main, Feb 6 2024, 20:19:44) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='DESKTOP-NV274K6', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')
$ nvidia-smi
Mon Jun 17 13:59:12 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.40.06 Driver Version: 551.23 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce GTX 1650 On | 00000000:01:00.0 Off | N/A |
| N/A 60C P0 15W / 50W | 202MiB / 4096MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 98662 C /python3.12 N/A |
+-----------------------------------------------------------------------------------------+
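The System info block has the layout printed by jax.print_environment_info(), which appends nvidia-smi output when that tool is available. A short sketch for regenerating it, assuming a JAX release that provides this function (0.4.28, as reported, does); jax.default_backend() additionally shows which backend jax.jit uses by default:

```python
import jax

# Versions, devices, process count and platform (plus nvidia-smi, if present).
info = jax.print_environment_info(return_string=True)
print(info)

# Backend that jax.jit will use by default, e.g. "cpu" or "gpu".
print(jax.default_backend())
```

In this report, jax.devices lists only CpuDevice(id=0) even though nvidia-smi shows a GeForce GTX 1650, so the failing lowering ran on the CPU backend.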