
jax.xla_computation as_hlo_text() #21990

Open
aoshujiaocheng opened this issue Jun 20, 2024 · 2 comments
Labels
enhancement New feature or request


@aoshujiaocheng

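```python
import jax

# Note: jax.xla_computation was deprecated in JAX 0.4.30.
def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.xla_computation(f)(3.)  # XlaComputation traced at x = 3.0
print(c.as_hlo_text())          # HLO text of that computation
```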

How can I convert the string returned by c.as_hlo_text() back into the computation object c?

@aoshujiaocheng aoshujiaocheng added the enhancement New feature or request label Jun 20, 2024
@justinjfu
Collaborator

JAX doesn't have Python bindings for parsing HLO strings; however, the machinery exists in XLA/C++. If you just want to run the function, you could perhaps try running the module with XLA's infrastructure (run_hlo_module: https://openxla.org/xla/tools).
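As a minimal sketch of that route, assuming JAX 0.4.30 or later (where jax.xla_computation is deprecated), the HLO text can be produced with the jax.jit(...).lower(...) API and written to a file for XLA's command-line tools. The file name `f.hlo` and the temporary directory are illustrative assumptions; the linked tools page describes how to invoke run_hlo_module on such a file.

```python
import pathlib
import tempfile

import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# Trace and lower f for a concrete example argument (3.0), as in the question.
lowered = jax.jit(f).lower(3.)

# HLO text of the lowered module (the text format that as_hlo_text() prints).
hlo_text = lowered.as_text(dialect="hlo")

# Hypothetical output location; any writable path works.
out_path = pathlib.Path(tempfile.gettempdir()) / "f.hlo"
out_path.write_text(hlo_text)
print(f"wrote {len(hlo_text)} characters of HLO to {out_path}")
```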

Also, as a heads-up: jax.xla_computation will be deprecated in the upcoming 0.4.30 release.

@yashk2810
Member

Note that 0.4.30 has been released and jax.xla_computation is now deprecated. Please use jax.jit(f).lower(3.).compiler_ir('hlo') as the replacement.
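A minimal sketch of the replacement workflow, assuming JAX 0.4.30 or later. `jax.jit(f).lower(3.)` returns a lowered object; here `as_text(dialect="hlo")` is used to get the HLO text directly, while the default `as_text()` returns StableHLO. Since JAX cannot parse HLO text back into a computation from Python, the sketch keeps the lowered object and compiles it instead of round-tripping through a string.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# Replaces jax.xla_computation(f)(3.): trace and lower f for a float scalar.
lowered = jax.jit(f).lower(3.)

# HLO text, the counterpart of c.as_hlo_text() in the question.
hlo_text = lowered.as_text(dialect="hlo")

# Without a dialect argument, as_text() returns StableHLO (MLIR text).
stablehlo_text = lowered.as_text()

# Keep the object rather than reparsing text: compile it and call it.
compiled = lowered.compile()
result = compiled(3.)
print(hlo_text)
print(result)  # value of sin(cos(3.0))
```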


3 participants