Add and Remove ZeRO 3 Hooks #5658
Conversation
@jomayeri, please add some unit tests.
```python
@contextmanager
def unwrap_model_for_generation(model):
```
It is better to use general naming that describes what the utility does rather than a specific usage like generation.
Suggested change:
```diff
-def unwrap_model_for_generation(model):
+def unshard_and_remove_hooks(model):
```
```python
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
    optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
    optimizer_offload = model.optimizer
```
Can you explain this ZeRO 3 case where we have hooks attached to the optimizer?
```python
for hook in optimizer_offload.forward_hooks:
```
The hooks are associated with the parameters, not the optimizer, so this naming is a bit confusing. Let's clarify it.
Suggested change:
```diff
-for hook in optimizer_offload.forward_hooks:
+for hook in parameter_offload.forward_hooks:
```
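For context, here is a minimal sketch of how the add/remove pair might look once the fragments quoted in this review are assembled. The `forward_hooks` handle list is what this PR appears to introduce; the `_get_offload` helper name is hypothetical, and `_register_hooks_recursively` / `.module` are assumptions based on DeepSpeed's `DeepSpeedZeRoOffload` internals, not a confirmed public API:

```python
# Sketch only: assembles the fragments quoted in this review thread.
def _get_offload(model):
    # Hypothetical helper: the object holding the ZeRO 3 hooks lives on the
    # optimizer, either behind `parameter_offload` or on the optimizer itself.
    if model.optimizer is None:
        return None
    if hasattr(model.optimizer, "parameter_offload"):
        return model.optimizer.parameter_offload
    return model.optimizer


def remove_hooks(model):
    """Detach the ZeRO 3 forward hooks so the module runs like a plain model."""
    offload = _get_offload(model)
    if offload is None:
        return
    for hook in offload.forward_hooks:
        hook.remove()  # each entry is a torch.utils.hooks.RemovableHandle
    offload.forward_hooks = []


def add_hooks(model):
    """Re-register the ZeRO 3 forward hooks after generation is done."""
    offload = _get_offload(model)
    if offload is None:
        return
    # Assumed from DeepSpeed's parameter-offload internals.
    offload._register_hooks_recursively(offload.module)
```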
Gives the ability to add and remove the forward hooks in ZeRO 3 by using a context manager. These code changes were taken from a Hugging Face PR and integrated for direct support in DeepSpeed.
This is useful in the inference case, and the speedup can be observed here.
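For illustration, a minimal usage sketch, assuming the context manager pairs hook removal with `deepspeed.zero.GatheredParameters` to materialize the full weights during generation. `remove_hooks`/`add_hooks` refer to the sketch above, and the `engine`/`generate` usage at the bottom is illustrative rather than part of this PR:

```python
import deepspeed
from contextlib import contextmanager


@contextmanager
def unwrap_model_for_generation(model):
    # Gather all ZeRO 3 partitioned parameters on each rank, then detach the
    # forward hooks so the wrapped module behaves like an ordinary dense model.
    with deepspeed.zero.GatheredParameters(list(model.parameters())):
        remove_hooks(model)  # sketched in the review thread above
        yield model
        add_hooks(model)     # restore the ZeRO 3 hooks for training


# Hypothetical usage: `engine` is the object returned by deepspeed.initialize(...)
with unwrap_model_for_generation(engine) as unwrapped:
    outputs = unwrapped.module.generate(input_ids, max_new_tokens=32)
```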