Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add and Remove ZeRO 3 Hooks #5658

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open

Add and Remove ZeRO 3 Hooks #5658

wants to merge 6 commits into from

Conversation

jomayeri
Copy link
Contributor

Gives the ability to add and remove the forward hooks in ZeRO 3 by using a context manager. These code changes were taken from a Huggingface PR and integrated for direct support in DeepSpeed.

This is useful in the inference case and the speedup can be observed here.

@tjruwase
Copy link
Contributor

@jomayeri, please add some unit tests

@tjruwase tjruwase requested review from tohtana and removed request for mrwyattii June 18, 2024 00:25


@contextmanager
def unwrap_model_for_generation(model):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to use general naming that describes what the utility does rather than specific usage like generation.

Suggested change
def unwrap_model_for_generation(model):
def unshard_and_remove_hooks(model):

if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this zero-3 case where we have hooks attached to the optimizer?

elif model.optimizer is not None:
optimizer_offload = model.optimizer

for hook in optimizer_offload.forward_hooks:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hooks are associated with parameters not optimizer, so this naming is a bit confusing. Let's clarify this.

Suggested change
for hook in optimizer_offload.forward_hooks:
for hook in parameter_offload.forward_hooks:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants