
[CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph #5604

Open
delock wants to merge 20 commits into master

Conversation

delock
Contributor

@delock delock commented Jun 3, 2024

This PR allows deepspeed.comm.inference_all_reduce() to enter the torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the inference_all_reduce() C++ kernel as a pybind11 function so it could be called from Python code. However, a pybind function is opaque to PyTorch, so the graph breaks whenever inference_all_reduce is called.

We address this by registering inference_all_reduce as a PyTorch custom op, torch.ops.deepspeed.inference_all_reduce, so it can be captured in the PyTorch graph.
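
For reference, here is a minimal hypothetical sketch of the same idea on the Python side. The PR itself registers the op from the C++ extension; the namespace schema and the stand-in CPU implementation below are assumptions for illustration, not DeepSpeed's actual code:

import torch
import torch.distributed as dist

# Hypothetical sketch: expose a kernel as a custom op under the "deepspeed" namespace.
# The real PR does this in C++; the schema and CPU fallback here are illustrative only.
lib = torch.library.Library("deepspeed", "DEF")
lib.define("inference_all_reduce(Tensor input) -> Tensor")

def _inference_all_reduce_cpu(input):
    # Stand-in implementation; the real op dispatches to DeepSpeed's optimized C++ kernel.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(input)
    return input

lib.impl("inference_all_reduce", _inference_all_reduce_cpu, "CPU")

# Once registered, the op is reachable as torch.ops.deepspeed.inference_all_reduce and
# torch.compile can keep it in the captured graph instead of breaking on a pybind call.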

The trace output from TorchInductor:

class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"):
        # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor)
        inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3)

        # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input)
        permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute);  primals_2 = permute = None

        # No stacktrace found for following nodes
        copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce);  primals_3 = None
        return [addmm, inference_all_reduce]
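
For context, a minimal sketch of the kind of model that produces a trace like the one above. The module and shapes below mirror the trace but are assumptions, not the actual test_allreduce.py; it also assumes a CPU DeepSpeed build with this PR and the CCL backend:

import torch
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed("ccl")  # assumed backend for CPU tensor-parallel inference

class TPLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # With this PR, this call maps to torch.ops.deepspeed.inference_all_reduce,
        # so torch.compile keeps it in the graph instead of falling back to eager.
        dist.inference_all_reduce(x)
        return self.linear(x)

compiled = torch.compile(TPLinear())
print(compiled(torch.randn(4, 4)))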

Note that in this PR the CPU inference_all_reduce op does not handle the multi-node case or the FP16 data type. For FP16 support, we will align with the PyTorch CPU FP16 plan. For multi-node, we are still investigating upstreaming oneCCL integration into PyTorch, so that oneCCL can be used for multi-node tensor-parallel inference with PyTorch.

This PR is independent of #5571; they can work separately or together without issue.

@delock delock marked this pull request as draft June 3, 2024 08:28
@delock delock marked this pull request as ready for review June 6, 2024 05:30
@tjruwase tjruwase requested review from tohtana and umchand and removed request for arashb, awan-10 and mrwyattii June 21, 2024 22:23
@tohtana
Contributor

tohtana commented Jun 21, 2024

@delock Thank you for the great PR. I didn't know we could avoid some graph breaks by registering a C++ extension op as a torch operator. This approach will definitely be useful for many features in DeepSpeed.
Let's merge it after it passes all the tests.

@delock
Contributor Author

delock commented Jun 22, 2024

> @delock Thank you for the great PR. I didn't know we could avoid some graph breaks by registering a C++ extension op as a torch operator. This approach will definitely be useful for many features in DeepSpeed. Let's merge it after it passes all the tests.

Hi @tohtana, the formatting is fixed. The other error is an HF hub connection issue; it should pass on rerun.

@tohtana tohtana enabled auto-merge June 24, 2024 20:19