Is there a way to control pjit to dynamically partition the input IDs across different ranks? #21902
Hi @MoFHeka, could you clarify how your question relates to JAX? JAX does not have a concept of a rank (other than array rank).
@superbobry Sorry, my mistake. A simpler description: when I use pmap, how do I implement an all2allv operator in the transformed code? The data-parallel (batch) dimension requires a dynamically sized data exchange. For example, GPU0 and GPU1 run data-parallel training. GPU0 gets [0,1,2,2] and GPU1 gets [0,3,2,1], but the hash table on GPU0 only stores the keys 0,2,..., and GPU1 stores 1,3,.... So there has to be an alltoallv operator, where GPU0 sends [1] and receives [0,2], and GPU1 sends [0,2] and receives [1]. Each GPU then looks up the values and returns them to the originating GPU.
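For concreteness, here is a small NumPy illustration of the routing implied above, assuming the hypothetical placement rule that key k is stored on device k % num_devices (the actual hash-table placement may differ):

```python
import numpy as np

num_devices = 2
ids_per_device = {0: np.array([0, 1, 2, 2]), 1: np.array([0, 3, 2, 1])}

for dev, ids in ids_per_device.items():
    owner = ids % num_devices    # assumed rule: key k lives on device k % num_devices
    local = ids[owner == dev]    # keys this device can look up itself
    send = ids[owner != dev]     # keys that must be sent to the other device
    print(f"device {dev}: local {local.tolist()}, send {send.tolist()}")
# device 0: local [0, 2, 2], send [1]
# device 1: local [3, 1], send [0, 2]
```

The per-device send buffers have different sizes ([1] vs [0, 2]), which is exactly why a plain all-to-all with static shapes does not fit directly.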
Try using shard_map? https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
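shard_map gives you a per-device view in which you can call collectives directly. A minimal sketch (fixed-size blocks, exactly 2 devices, and the axis name "x" are assumptions) might look like:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Assumes at least 2 devices; the mesh and axis name "x" are illustrative.
mesh = Mesh(np.array(jax.devices()[:2]), ("x",))

def swap_blocks(local_ids):
    # Inside shard_map each device sees only its own block; ppermute
    # sends that block to the other device on a 2-device ring.
    return jax.lax.ppermute(local_ids, axis_name="x",
                            perm=[(0, 1), (1, 0)])

ids = jnp.array([0, 1, 2, 2, 0, 3, 2, 1])   # [device0 block | device1 block]
swapped = shard_map(swap_blocks, mesh=mesh,
                    in_specs=P("x"), out_specs=P("x"))(ids)
print(swapped)  # [0 3 2 1 0 1 2 2]
```

Note that this only swaps fixed-size blocks; the variable-sized exchange discussed below still needs padding or a similar workaround.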
@yashk2810 Thanks for your answer, but I don't know whether shard_map supports exchanging tensors with different dynamic shapes along the batch dimension between different GPUs. I also need to lower the collective operator (alltoallv) into send/recv (peer-to-peer) expressions so the communication can be overlapped with compute.
I will let Yash comment on dynamic shapes. Re "transform the collective": I don't think we have send/recv operations in lax, so I'm not quite sure what kind of transform you are talking about. Are these custom JAX operations you implemented?
@superbobry Assume I already have JAX lookup-table operations. Dynamic shapes are crucial in recommendation algorithms, and we now need to explore large generative recommendation models, so JAX or Torch is needed.
@yashk2810 @superbobry I'm sorry to bother you, but I found that jax.lax.ppermute requires the same amount of data on every shard, while I need to exchange a different amount of data per shard.
I will defer the answer to Yash, as I'm not yet very familiar with sharding APIs in JAX.
I need to build a model like this:
There is a very large distributed dynamic-shape embedding, which can be seen as a hash table.
In every DP rank, when workers get the input IDs they need to transfer the IDs to other workers for the table lookup.
For example, rank 0 gets [0,1,2,2] and rank 1 gets [0,3,2,1], but the hash table on rank 0 only stores the keys 0,2,..., and rank 1 stores 1,3,.... So there has to be an alltoallv operator, where rank 0 sends [1] and receives [0,2], and rank 1 sends [0,2] and receives [1]. Each rank then looks up the values and returns them to the originating rank.
Note: In practice, this procedure uses a send_recv operator to implement an asynchronous ring alltoall so the table-lookup time can be overlapped with communication.
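Since lax has no peer-to-peer send/recv, one workaround (a sketch only, not an official recipe) is to pad each rank's outgoing keys to a static MAX_IDS, exchange the fixed-size buffers with ppermute inside shard_map, and mask out the padding before the lookup. MAX_IDS, the PAD sentinel, the axis name "dp", and the key -> rank rule (k % NUM_DEV) are all assumptions for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

NUM_DEV = 2
MAX_IDS = 8          # assumed static upper bound on keys forwarded per step
PAD = -1             # sentinel marking unused (padding) slots

mesh = Mesh(np.array(jax.devices()[:NUM_DEV]), ("dp",))
ring = [(i, (i + 1) % NUM_DEV) for i in range(NUM_DEV)]

def exchange_ids(local_ids):
    # Per-device body; local_ids is this rank's batch of lookup keys.
    me = jax.lax.axis_index("dp")
    owner = local_ids % NUM_DEV                      # assumed placement rule
    remote = jnp.where(owner != me, local_ids, PAD)  # keys to forward elsewhere
    # Copy into a fixed-size buffer so every rank sends the same shape.
    buf = jnp.full((MAX_IDS,), PAD, dtype=local_ids.dtype)
    buf = buf.at[: local_ids.shape[0]].set(remote)
    # One ring step: every rank hands its buffer to the next rank.
    recv = jax.lax.ppermute(buf, axis_name="dp", perm=ring)
    # Keep only the keys this rank actually owns; everything else stays PAD.
    return jnp.where((recv != PAD) & (recv % NUM_DEV == me), recv, PAD)

ids = jnp.array([0, 1, 2, 2, 0, 3, 2, 1])            # [rank0 batch | rank1 batch]
received = shard_map(exchange_ids, mesh=mesh,
                     in_specs=P("dp"), out_specs=P("dp"))(ids)
# rank 0 ends up with the padded keys [0, 2], rank 1 with [1].
```

With more than two devices, the same pattern would run NUM_DEV - 1 ppermute steps, and the local table lookup for each received chunk could be done while the next chunk is in flight, approximating the overlapped ring alltoall described in the note above. As far as I know there is no true alltoallv with dynamic sizes in lax today, so the padding bound MAX_IDS has to be chosen conservatively.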