jaxlib wheel hardcoded to manylinux2014 platform #22034

Open
mrodden opened this issue Jun 21, 2024 · 3 comments
Labels
bug Something isn't working

mrodden commented Jun 21, 2024

Description

I am opening this issue to start a discussion about manylinux support for JAX/jaxlib. I am currently working on building manylinux_2_28 jaxlib wheels with ROCm support for ROCm users.

Currently, jaxlib wheels are hardcoded to be tagged as manylinux2014. See https://github.com/google/jax/blob/main/jaxlib/tools/build_wheel.py#L167 and https://github.com/google/jax/blob/main/jax/tools/build_utils.py#L56-L58
This hardcoded platform name forces the wheel build process to tag both the wheel filename and its metadata (the WHEEL file inside the zip) as manylinux2014. That tag is only correct if the binaries were built and linked against glibc 2.17 (a rather old release) or older.
I don't know how Google builds these wheels internally, but I expect this produces incorrectly tagged wheels for almost everyone else who builds jaxlib on Linux.
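
To make the tagging constraint concrete: under PEP 600, `manylinux2014` is a legacy alias for `manylinux_2_17`, i.e. a promise that the wheel runs on any Linux with glibc 2.17 or newer. The sketch below uses hypothetical helper names (it is not part of jaxlib's build tooling) and ignores compressed tag sets such as `a.b`; it reads the platform tag from a wheel filename and reports the minimum glibc that tag promises.

```python
from __future__ import annotations

import re

# PEP 600 legacy aliases and the glibc (major, minor) each one implies.
LEGACY_ALIASES = {
    "manylinux1": (2, 5),
    "manylinux2010": (2, 12),
    "manylinux2014": (2, 17),
}

# PEP 600 form: manylinux_<glibc major>_<glibc minor>_<arch>
_PEP600 = re.compile(r"^manylinux_(\d+)_(\d+)_(.+)$")


def platform_tag_of(wheel_filename: str) -> str:
    """Return the platform tag, i.e. the last dash-separated field of a wheel filename."""
    if not wheel_filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {wheel_filename}")
    return wheel_filename[: -len(".whl")].split("-")[-1]


def required_glibc(platform_tag: str) -> tuple[int, int] | None:
    """Minimum glibc a manylinux tag promises to run on; None for tags like linux_x86_64."""
    m = _PEP600.match(platform_tag)
    if m:
        return int(m.group(1)), int(m.group(2))
    for alias, version in LEGACY_ALIASES.items():
        if platform_tag.startswith(alias + "_"):
            return version
    return None
```

For example, with an illustrative filename, `required_glibc(platform_tag_of("jaxlib-0.4.30-cp312-cp312-manylinux2014_x86_64.whl"))` returns `(2, 17)`, while `required_glibc("linux_x86_64")` returns `None` because a plain `linux_*` tag makes no portability promise.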

I have tested a change that removes the platform name override. Building on a CentOS 8 / AlmaLinux 8 system then produces a linux_x86_64 wheel, as expected. Running auditwheel repair on that wheel correctly scans the glibc-versioned symbols and creates a new manylinux_2_27 wheel (the highest required symbol version in the output wheel is indeed 2.27). This seems more correct, although I have not yet tried it with the new JAX plugin wheel builds.
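
The manylinux_2_27 result comes from the highest `GLIBC_x.y` symbol version the binaries reference. As a rough illustration of that idea only (not auditwheel's implementation, which also checks other symbol versions such as `GLIBCXX` and an allow-list of external libraries), the sketch below scans `objdump -T` output for `GLIBC_` versions and builds the corresponding PEP 600 tag. The function names are hypothetical, and `dynamic_symbols` assumes binutils' `objdump` is on `PATH`.

```python
from __future__ import annotations

import re
import subprocess

# Matches GLIBC_2.17 or (GLIBC_2.2.5); does not match GLIBCXX_3.4.21 or GLIBC_PRIVATE.
_GLIBC_VERSION = re.compile(r"\bGLIBC_(\d+)\.(\d+)(?:\.\d+)?")


def dynamic_symbols(path: str) -> str:
    """Return the dynamic symbol table of a shared object as printed by `objdump -T`."""
    result = subprocess.run(
        ["objdump", "-T", path], check=True, capture_output=True, text=True
    )
    return result.stdout


def max_glibc_version(objdump_text: str) -> tuple[int, int] | None:
    """Highest (major, minor) GLIBC symbol version referenced, or None if there are none."""
    versions = {(int(a), int(b)) for a, b in _GLIBC_VERSION.findall(objdump_text)}
    return max(versions) if versions else None


def manylinux_tag(glibc: tuple[int, int], arch: str = "x86_64") -> str:
    """Build a PEP 600 platform tag such as manylinux_2_27_x86_64."""
    major, minor = glibc
    return f"manylinux_{major}_{minor}_{arch}"
```

A real check would take the maximum across every shared object in the wheel, e.g. by calling `max_glibc_version(dynamic_symbols(path))` for each `.so` and keeping the largest non-`None` result.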

Some open questions:

- Is there a reason the builds were hardcoded to manylinux2014?
- Is making the above change and using auditwheel repair a workable path forward for upstream JAX/jaxlib builds?
- What manylinux target does jaxlib need to support going forward? manylinux2014 appears to be deprecated at the end of June 2024. From the pypa/manylinux repo:

> PEP 599 defines the following platform tags: manylinux2014_x86_64, manylinux2014_i686, manylinux2014_aarch64, manylinux2014_armv7l, manylinux2014_ppc64, manylinux2014_ppc64le and manylinux2014_s390x. Wheels are built on CentOS 7 which will reach End of Life (EOL) on June 30th, 2024.

System info (python version, jaxlib version, accelerator, etc.)

mrodden added the bug label on Jun 21, 2024
@hawkinsp (Member)

In general, our build script produces the tags that our release builds should produce; it doesn't try to adapt the tags to the environment.

I'm a little reluctant to run auditwheel repair because I don't want auditwheel changing the wheel.

One option would be for us to default to a relaxed tag like linux_x86_64, and change our release builds to override the tag to a manylinux tag.
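
A minimal sketch of that split, with hypothetical function names (this is not the existing `build_wheel.py` code): the build defaults to the host's relaxed tag derived from `sysconfig.get_platform()`, and only a release job supplies an explicit manylinux tag.

```python
from __future__ import annotations

import sysconfig


def host_platform_tag() -> str:
    """Relaxed tag for the build host, e.g. 'linux_x86_64'; makes no manylinux promise."""
    return sysconfig.get_platform().replace("-", "_").replace(".", "_")


def resolve_platform_tag(release_tag: str | None = None) -> str:
    """Use an explicit (release) tag if one is given, else fall back to the host tag."""
    return release_tag or host_platform_tag()
```

A local source build would call `resolve_platform_tag()` and get something like `linux_x86_64`, while a release job that controls its build environment would pass, say, `"manylinux_2_28_x86_64"`. The design keeps the portability claim in the one place that can actually verify it.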

@hawkinsp (Member)

As to ongoing support: we're not sure yet. One likely possibility is at least 2_28, but we may go newer. (The thing pushing us to newer is that we would like to build with C++20).

@mrodden (Author)

mrodden commented Jun 24, 2024

> One option would be for us to default to a relaxed tag like linux_x86_64, and change our release builds to override the tag to a manylinux tag.

I think this might be easiest, because it moves the "generalization" step into the release process that comes after the build.

I am also a bit worried about auditwheel repair changing the wheel, since it copies any third-party shared objects (.so files) into the wheel and rewrites RPATHs to use those bundled copies instead of the external ones. This can be worked around with its --exclude option, but it seems too easy to miss something. My current plan is to write something that uses auditwheel's functionality to retag the wheel in a more controlled manner: do the lddtree scan for all versioned symbols, collect the maximum required version of each library, and then output the manylinux_x_y wheel without any RPATH or .so file changes.
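
A minimal sketch of the "retag without touching binaries" step, assuming a wheel filename with exactly five fields (no build tag), a single `Tag:` line, and a `.dist-info` directory named `{name}-{version}`. It rewrites only the filename, the `Tag:` line in `WHEEL`, and the `WHEEL` entry in `RECORD` (whose hash and size must match the new contents); every other member, including shared objects, is copied byte-for-byte. `retag_wheel` and `record_hash` are hypothetical helpers, not auditwheel APIs, and choosing a correct target tag (for example from a symbol scan) is left to the caller.

```python
import base64
import hashlib
import os
import zipfile


def record_hash(data: bytes) -> str:
    """Hash in wheel RECORD format: sha256=<urlsafe base64 digest, no '=' padding>."""
    digest = hashlib.sha256(data).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


def retag_wheel(src: str, new_platform: str, out_dir: str) -> str:
    """Copy `src` into `out_dir` with a new platform tag; return the new wheel's path.

    Only the filename, the WHEEL `Tag:` line and WHEEL's RECORD entry change.
    Every other member (including .so files) is copied byte-for-byte, so no
    RPATHs are rewritten and no external libraries are bundled.
    """
    # Assumes name-version-python-abi-platform.whl (no optional build tag).
    name, version, py, abi, _ = os.path.basename(src)[: -len(".whl")].split("-")
    dst = os.path.join(out_dir, f"{name}-{version}-{py}-{abi}-{new_platform}.whl")
    wheel_path = f"{name}-{version}.dist-info/WHEEL"
    record_path = f"{name}-{version}.dist-info/RECORD"

    with zipfile.ZipFile(src) as zin:
        # Replace the first Tag: line and drop any others (single-tag wheels assumed).
        lines, tagged = [], False
        for line in zin.read(wheel_path).decode("utf-8").splitlines(keepends=True):
            if line.startswith("Tag:"):
                if not tagged:
                    lines.append(f"Tag: {py}-{abi}-{new_platform}\n")
                    tagged = True
                continue
            lines.append(line)
        new_wheel = "".join(lines).encode("utf-8")

        # RECORD lists path,hash,size for each file; refresh the WHEEL entry only.
        record = []
        for line in zin.read(record_path).decode("utf-8").splitlines():
            if line.split(",", 1)[0] == wheel_path:
                line = f"{wheel_path},{record_hash(new_wheel)},{len(new_wheel)}"
            record.append(line)
        new_record = ("\n".join(record) + "\n").encode("utf-8")

        with zipfile.ZipFile(dst, "w", zipfile.ZIP_DEFLATED) as zout:
            for info in zin.infolist():
                if info.filename == wheel_path:
                    data = new_wheel
                elif info.filename == record_path:
                    data = new_record
                else:
                    data = zin.read(info)
                zout.writestr(info, data)
    return dst
```

Because the original `ZipInfo` entries are reused, file names, timestamps and compression settings of the untouched members carry over unchanged.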

> One likely possibility is at least 2_28, but we may go newer. (The thing pushing us to newer is that we would like to build with C++20).

I am sure going newer than that would make a lot of folks who use JAX on older systems very upset. All the users I am supporting right now seem to be on Ubuntu 20.04 or EL8 derivatives, which need wheels targeting glibc 2.28 or older. I haven't looked at the libstdc++ (GLIBCXX) versions on those distros, but I am almost certain they do not universally support C++20.
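
For reference, EL8 ships glibc 2.28 and Ubuntu 20.04 ships glibc 2.31, so both can install manylinux_2_28 wheels, while a manylinux_2_31 wheel would already exclude EL8. A quick glibc-only check, with hypothetical helper names (it ignores libstdc++ and any other runtime requirements):

```python
from __future__ import annotations

import platform


def host_glibc() -> tuple[int, int] | None:
    """glibc (major, minor) of the running system, or None on non-glibc systems."""
    lib, version = platform.libc_ver()
    parts = version.split(".")
    if lib != "glibc" or len(parts) < 2:
        return None
    return int(parts[0]), int(parts[1])


def can_install(wheel_glibc: tuple[int, int], host: tuple[int, int] | None) -> bool:
    """PEP 600: a manylinux_X_Y wheel requires host glibc >= X.Y."""
    return host is not None and host >= wheel_glibc
```

On a given machine, `can_install((2, 28), host_glibc())` answers whether a manylinux_2_28 wheel's glibc requirement is met there.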
