sparse: bcoo_spdot_general has incorrect abstract_eval #21921

Closed
Alan-Chen99 opened this issue Jun 17, 2024 · 1 comment · Fixed by #22093

Labels
bug Something isn't working

@Alan-Chen99

Description

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import sparse as sp
from jax.experimental.sparse.util import SparseInfo

# lhs: a 10x10 matrix with ones on the diagonal, stored with one batch dimension (n_batch=1)
lhs = sp.BCOO(
    (
        jnp.ones((1, 1)),
        lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0),
    ),
    shape=(10, 10),
)

# rhs: a length-10 vector with a single nonzero (1.0) at index 3
rhs = sp.BCOO(
    (jnp.array([1.0]), jnp.array([[3]])),
    shape=(10,),
)

# the data shape in the returned aval is wrong (compare with the lowered result below)
print(
    sp.bcoo_spdot_general_p.abstract_eval(
        lhs.data.aval,
        lhs.indices.aval,
        rhs.data.aval,
        rhs.indices.aval,
        dimension_numbers=(((1,), (0,)), ((), ())),
        lhs_spinfo=SparseInfo(
            shape=(10, 10), indices_sorted=False, unique_indices=False
        ),
        preferred_element_type=None,
        rhs_spinfo=SparseInfo(shape=(10,), indices_sorted=False, unique_indices=False),
    )
)


def inner(x, y):
    return lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())))


print(sp.sparsify(inner)(lhs, rhs))  # this works fine
print(jax.jit(sp.sparsify(inner))(lhs, rhs))

results in

((ShapedArray(float32[1,1]), ShapedArray(int32[10,1,0])), frozenset())
BCOO(float32[10], nse=1, n_batch=1)
Traceback (most recent call last):
  File "/nix/store/g2zlb8pq0rd2kyz7q6v1l2ivz96am8lp-env/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 994, in lower_jaxpr_to_module
    if not ctx.module.operation.verify():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: unknown: type of return operand 0 ('tensor<10x1xf32>') doesn't match function result type ('tensor<1x1xf32>') in function @main
 note: unknown: see current operation: "func.return"(%104, %106) : (tensor<10x1xf32>, tensor<10x1x0xi32>) -> ()

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "***********************************.py", line 42, in <module>
    print(jax.jit(sp.sparsify(inner))(lhs, rhs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Cannot lower jaxpr with verifier errors:
	type of return operand 0 ('tensor<10x1xf32>') doesn't match function result type ('tensor<1x1xf32>') in function @main
		at loc(unknown)
	see current operation: "func.return"(%104, %106) : (tensor<10x1xf32>, tensor<10x1x0xi32>) -> ()
		at loc(unknown)Define JAX_DUMP_IR_TO to dump the module.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
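
For comparison, here is a small dense sanity check (not part of the original repro; it only assumes BCOO.todense()) showing that the contraction itself should produce a length-10 vector, matching the un-jitted sparsify result above:

from jax import lax
from jax import numpy as jnp
from jax.experimental import sparse as sp

# Same operands as in the repro above.
lhs = sp.BCOO(
    (jnp.ones((1, 1)), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)),
    shape=(10, 10),
)
rhs = sp.BCOO((jnp.array([1.0]), jnp.array([[3]])), shape=(10,))

# Dense reference: ones on the diagonal of a 10x10 matrix, contracted with a
# length-10 vector holding a single 1.0 at index 3.
dense_out = lax.dot_general(
    lhs.todense(), rhs.todense(), dimension_numbers=(((1,), (0,)), ((), ()))
)
print(dense_out.shape)  # (10,)

With nse=1 and n_batch=1, the data buffer of the sparse result should therefore have shape (10, 1), which appears to match the tensor<10x1xf32> that the lowered computation actually returns in the verifier error above, not the float32[1,1] reported by abstract_eval.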

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.12.2 (main, Feb  6 2024, 20:19:44) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='DESKTOP-NV274K6', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')


$ nvidia-smi
Mon Jun 17 13:59:12 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.40.06              Driver Version: 551.23         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce GTX 1650        On  |   00000000:01:00.0 Off |                  N/A |
| N/A   60C    P0             15W /   50W |     202MiB /   4096MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     98662      C   /python3.12                                 N/A      |
+-----------------------------------------------------------------------------------------+
@jakevdp (Collaborator) commented Jun 25, 2024

Thanks for the report! #22093 should fix this.
