
Is there a way to control the pjit to dynamically partition the input IDs into different ranks? #21902

Open
2 tasks done
MoFHeka opened this issue Jun 16, 2024 · 8 comments
Labels
needs info More information is required to diagnose & prioritize the issue. question Questions for the JAX team

Comments

@MoFHeka

MoFHeka commented Jun 16, 2024

I need to build a model like this:
There is a very large, distributed, dynamic-shape embedding, which can be thought of as a hash table.
In every data-parallel (DP) rank, when a worker receives its input IDs, it has to forward some of those IDs to other workers for the table lookup.
For example, rank 0 gets [0,1,2,2] and rank 1 gets [0,3,2,1], but the hash table on rank 0 only stores keys 0, 2, ..., while rank 1 stores 1, 3, .... This requires an alltoallv operator: rank 0 sends [1] and receives [0,2], and rank 1 sends [0,2] and receives [1]. The looked-up values are then returned to their original ranks.

Note: in practice this procedure uses send/recv operators to implement an asynchronous ring alltoall, so the table lookups can overlap with the transfers.
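
To make the exchange concrete, here is a minimal host-side NumPy sketch of the routing being described (assuming, purely for illustration, that key ownership is decided by `id % num_ranks`); because the per-rank bucket sizes differ, this is an alltoallv rather than a plain alltoall:

```python
import numpy as np

num_ranks = 2
inputs = {0: np.array([0, 1, 2, 2]), 1: np.array([0, 3, 2, 1])}

# Each rank buckets its local IDs by owning rank; bucket sizes vary per peer.
send = {src: {dst: inputs[src][inputs[src] % num_ranks == dst]
              for dst in range(num_ranks)}
        for src in range(num_ranks)}

# The alltoallv step: every rank gathers the buckets addressed to it.
recv = {dst: np.concatenate([send[src][dst] for src in range(num_ranks)])
        for dst in range(num_ranks)}

print(recv[0])  # IDs rank 0 looks up locally: [0 2 2 0 2]
print(recv[1])  # IDs rank 1 looks up locally: [1 3 1]
```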

Please:

  • Check for duplicate requests.
  • Describe your goal, and if possible provide a code snippet with a motivating example.
@MoFHeka MoFHeka added the enhancement New feature or request label Jun 16, 2024
@superbobry
Member

superbobry commented Jun 17, 2024

Hi @MoFHeka, could you clarify how your question relates to JAX? JAX does not have a concept of a rank (other than array rank), so jit probably cannot do the partitioning you are describing.

@superbobry superbobry added needs info More information is required to diagnose & prioritize the issue. question Questions for the JAX team and removed enhancement New feature or request labels Jun 17, 2024
@MoFHeka
Author

MoFHeka commented Jun 17, 2024

@superbobry Sorry, my mistake. A simpler way to put it: when I use pmap, how do I implement an all2allv operator inside the transformed code? The data-parallel (batch) dimension requires exchanging dynamically sized data.

For example, GPU 0 and GPU 1 run data-parallel training. GPU 0 gets [0,1,2,2] and GPU 1 gets [0,3,2,1], but the hash table on GPU 0 only stores keys 0, 2, ..., while GPU 1 stores 1, 3, .... This requires an alltoallv operator: GPU 0 sends [1] and receives [0,2], and GPU 1 sends [0,2] and receives [1]. The looked-up values are then sent back to their original GPUs.
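
For contrast, here is a minimal sketch of the static-shape primitive pmap does support, jax.lax.all_to_all, which exchanges equally sized chunks between devices and therefore cannot express the variable-size alltoallv above by itself:

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

@partial(jax.pmap, axis_name='dev')
def exchange(ids):
    # Splits the per-device row into n_dev equal chunks and swaps chunk i
    # with device i; every device must send the same number of elements.
    return jax.lax.all_to_all(ids, 'dev', split_axis=0, concat_axis=0)

ids = jnp.arange(n_dev * n_dev).reshape(n_dev, n_dev)  # equal share per device
print(exchange(ids))
```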

@yashk2810
Member

yashk2810 commented Jun 17, 2024

Try using shard_map? https://jax.readthedocs.io/en/latest/notebooks/shard_map.html

@MoFHeka
Author

MoFHeka commented Jun 17, 2024

@yashk2810 Thanks for your answer, but I don't know whether shard_map supports exchanging tensors whose batch dimension has a different, dynamic size on each GPU. I also need to lower the collective operator (alltoallv) into send/recv (peer-to-peer) expressions so that the communication can be overlapped with the computation operators.
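
For reference, a minimal static-shape sketch with shard_map: pad each per-destination bucket to a fixed capacity and exchange the valid lengths alongside the payload. CAPACITY and the axis name 'dp' are assumptions for illustration, not anything the thread prescribes:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

CAPACITY = 8                                    # hypothetical per-peer bound
mesh = Mesh(np.array(jax.devices()), ('dp',))
n_dev = jax.device_count()

def _exchange(padded_ids, counts):
    # Per-shard shapes: padded_ids (n_dev, CAPACITY), counts (n_dev,);
    # row d holds the (padded) IDs destined for device d.
    out_ids = jax.lax.all_to_all(padded_ids, 'dp', split_axis=0, concat_axis=0)
    out_counts = jax.lax.all_to_all(counts, 'dp', split_axis=0, concat_axis=0)
    return out_ids, out_counts

exchange = shard_map(_exchange, mesh=mesh,
                     in_specs=(P('dp'), P('dp')), out_specs=(P('dp'), P('dp')))

padded = jnp.zeros((n_dev * n_dev, CAPACITY), dtype=jnp.int32)  # padded buckets
counts = jnp.zeros((n_dev * n_dev,), dtype=jnp.int32)           # valid lengths
swapped_ids, swapped_counts = exchange(padded, counts)
```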

@superbobry
Member

I will let Yash comment on dynamic shapes in shard_map, but I suspect that this use-case is not supported (since most of JAX assumes static shapes atm).

Re "transform the collective": I don't think we have send/recv operations in lax, so I'm not quite sure what kind of transform you are talking about. Are these custom JAX operations you implemented?

@MoFHeka
Author

MoFHeka commented Jun 18, 2024

@superbobry Assume I already have JAX lookup-table operations.
I need to split the ring alltoall collective into several steps so that, while one chunk of data is being transferred, the chunks already received are querying the hash table, hiding the transmission overhead.

Dynamic shapes are crucial in recommendation algorithms, and we now need to explore large generative recommendation models, so JAX or PyTorch is needed.

[Image: ring all-to-all visualization]
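
A rough sketch of that ring pattern using the static-shape primitives JAX does have: jax.lax.ppermute inside a pmap/shard_map context, where at each step the device looks up the chunk it currently holds while the buffer is rotated one hop. The padded buffer, the lookup callable, and the axis name are stand-ins, not real APIs from this thread:

```python
import jax
import jax.numpy as jnp

def ring_lookup(buf, axis_name, axis_size, lookup):
    # buf: fixed-size (padded) buffer of IDs currently resident on this device.
    # lookup: stand-in for the local hash-table query.
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    results = []
    for _ in range(axis_size):
        results.append(lookup(buf))                   # query the resident chunk...
        buf = jax.lax.ppermute(buf, axis_name, perm)  # ...while rotating one hop;
        # XLA may overlap this transfer with the next iteration's lookup.
    return jnp.stack(results)
```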

@MoFHeka
Author

MoFHeka commented Jun 25, 2024

@yashk2810 @superbobry I'm sorry to bother you, but I found that jax.lax.ppermute requires a tensor with the same amount of data on every shard, whereas I need to exchange a different amount of data per shard.
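
A short sketch of the usual workaround given that constraint: permute a fixed-capacity padded buffer together with the number of valid entries, then mask after the shuffle. CAPACITY and the axis name 'dp' are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

CAPACITY = 16  # hypothetical upper bound on IDs moved per step

def shift_one_hop(padded_ids, count, axis_size):
    # padded_ids: (CAPACITY,) buffer; count: scalar number of valid entries.
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    padded_ids = jax.lax.ppermute(padded_ids, 'dp', perm)
    count = jax.lax.ppermute(count, 'dp', perm)
    mask = jnp.arange(CAPACITY) < count      # keep only the valid entries
    return jnp.where(mask, padded_ids, -1), count
```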

@superbobry
Member

I will defer the answer to Yash, as I'm not yet very familiar with sharding APIs in JAX.
