[CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph #5604
base: master
Conversation
@delock Thank you for the great PR. I didn't know we could avoid some graph breaks by registering a C++ extension op as a torch operator. This approach will definitely be useful for many features in DeepSpeed.
Hi @tohtana, formatting is fixed. The other error is an HF hub connection issue; it should pass on rerun.
This PR allows `deepspeed.comm.inference_all_reduce()` to enter the torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel as a pybind function so it could be called from Python code. However, a pybind function cannot be recognized by PyTorch, so the graph breaks when `inference_all_reduce` is called. We address this issue by registering `inference_all_reduce` as a PyTorch custom op, `torch.ops.deepspeed.inference_all_reduce`, so it can be built into the PyTorch graph. The trace code emitted by TorchInductor shows the op captured in the compiled graph.
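The mechanism can be sketched in Python with `torch.library`. This is a minimal, illustrative stand-in, not the PR's actual implementation: DeepSpeed registers the real C++ kernel (e.g. via the `TORCH_LIBRARY` macro), and the single-rank no-op kernel below is a placeholder for the real all-reduce.

```python
import torch

# Define the op in an illustrative "deepspeed" namespace (a hypothetical
# stand-in for the registration DeepSpeed does in its C++ extension).
lib = torch.library.Library("deepspeed", "DEF")
lib.define("inference_all_reduce(Tensor input) -> Tensor")

# CPU kernel: a no-op placeholder for the real all-reduce. On a single
# rank, an all-reduce returns the input unchanged.
def _inference_all_reduce_cpu(input):
    return input.clone()

lib.impl("inference_all_reduce", _inference_all_reduce_cpu, "CPU")

# Meta kernel so the compiler can trace the op with fake tensors.
lib.impl("inference_all_reduce",
         lambda input: torch.empty_like(input), "Meta")

# Because the op is now a proper torch operator, torch.compile captures
# the call instead of graph-breaking on a pybind function.
@torch.compile(backend="eager")
def fn(x):
    return torch.ops.deepspeed.inference_all_reduce(x)

out = fn(torch.ones(4))
```

Once registered this way, the op appears as `torch.ops.deepspeed.inference_all_reduce` in the traced graph rather than forcing a fallback to eager execution at the call site.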
Note that in this PR the `inference_all_reduce` op for CPU does not handle multi-node setups or the FP16 data type. For FP16 support, we will align with PyTorch's CPU FP16 plan. For multi-node, we are still investigating upstreaming the oneCCL integration into PyTorch, so we can make use of oneCCL for multi-node tensor-parallel inference with PyTorch.
This PR is independent of #5571. They can work separately or together without issue.