
reduce cpu host overhead when using moe #5578

Open · ranzhejiang wants to merge 1 commit into master from zhejiang/reduce_host_overhead_moe

Conversation


@ranzhejiang ranzhejiang commented May 29, 2024

The .to('cpu') operation is not necessary for exp_counts, and it causes a device-to-host synchronization that damages performance.
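
For reference, the blocking behavior is easy to see in isolation. Below is a minimal standalone sketch (not part of this PR or of DeepSpeed; shapes and tensor names are illustrative) that queues some GPU work and compares keeping a result on the device with copying it to the host:

```python
# Minimal standalone sketch (not part of this PR): illustrates the
# device-to-host synchronization that .to('cpu') forces. Shapes and
# names are illustrative; values may overflow, only the timing matters.
import time
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device='cuda')
    # Queue a sizeable amount of asynchronous GPU work.
    for _ in range(10):
        x = x @ x
    exp_counts = torch.sum(x > 0, dim=0)

    # Keeping the result on the GPU returns immediately; kernel launches
    # are asynchronous from the host's point of view.
    t0 = time.perf_counter()
    kept = exp_counts.detach()
    print(f"stay on device: {(time.perf_counter() - t0) * 1e3:.3f} ms")

    # Copying to the CPU blocks the host until all queued kernels finish.
    t0 = time.perf_counter()
    moved = exp_counts.detach().to('cpu')
    print(f"copy to host:   {(time.perf_counter() - t0) * 1e3:.3f} ms")
```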

@ranzhejiang ranzhejiang requested a review from awan-10 as a code owner May 29, 2024 04:01
@loadams loadams requested a review from tohtana May 31, 2024 22:15
@@ -366,7 +366,7 @@ def top2gating(logits: Tensor,
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()

-    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
+    return l_aux, combine_weights, dispatch_mask, exp_counts
Contributor

Currently exp_counts is unused in forward() of any of the MoE classes, right?

Author

> Currently exp_counts is unused in forward() of any of the MoE classes, right?

Yes, I have tested it with Megatron-DeepSpeed and found that exp_counts is unused in forward() of any of the MoE classes.

Author

I have tested these changes on my GPU test platform and they work fine: no errors, and the loss stays the same as with the original code.

Contributor

@tohtana tohtana left a comment

@ranzhejiang Thank you for your contribution! I have a few questions about your changes. Can you clarify them?

@@ -322,7 +322,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts

     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
Contributor

Can the device of mask1 and mask2 be different from logits?

Author

> Can the device of mask1 and mask2 be different from logits?

Looking at lines 296 to 301, the computation of mask1 depends on logits, and torch operations keep their inputs' device, so mask1 and logits are on the same device. The same holds for mask2 (lines 309 to 311).
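
As a quick check of that reasoning, here is a small standalone sketch (not the DeepSpeed gating code; num_experts and the mask construction are simplified placeholders) confirming that tensors derived from logits inherit its device, so the added .to(logits.device) is effectively a no-op:

```python
# Standalone sketch (not the DeepSpeed gating code): tensors derived from
# logits stay on logits' device, and .to(device) returns the same tensor
# when the device already matches (no copy, no host synchronization).
# num_experts and the mask construction are simplified placeholders.
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_experts = 4
logits = torch.randn(16, num_experts, device=device)

# mask1/mask2 are built purely from logits, so they inherit its device.
mask1 = F.one_hot(logits.argmax(dim=1), num_classes=num_experts)
mask2 = F.one_hot(logits.argsort(dim=1, descending=True)[:, 1], num_classes=num_experts)
assert mask1.device == logits.device and mask2.device == logits.device

exp_counts = torch.sum(mask1 + mask2, dim=0).detach()
# Already on logits.device, so .to() hands back the same storage.
assert exp_counts.to(logits.device).data_ptr() == exp_counts.data_ptr()
print(exp_counts.device, exp_counts)
```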

@ranzhejiang ranzhejiang force-pushed the zhejiang/reduce_host_overhead_moe branch from e9e32f4 to d860d2c on June 11, 2024 03:32
@ranzhejiang
Author

Hi @tohtana, I have clarified the modifications you mentioned and retested this PR with Megatron-DeepSpeed on a GPU platform (8x A800). It runs well and the loss remains consistent with the original method. Could you please review it again? Thanks!
