Distributed checkpointing user guide #9494
base: yuya/add_checkpoints_section
Conversation
Signed-off-by: Mikołaj Błaż <[email protected]>
# Distributed checkpoint save
sharded_state_dict = {
    'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten, (0, rank, world_size))

Can you add a comment explaining `(0, rank, world_size)`?
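For what it's worth, my reading of the API (worth confirming against the `ShardedTensor.from_rank_offsets` docstring): each offset tuple is `(axis, axis_rank_offset, axis_fragmentations)`, i.e. the global tensor is split into `axis_fragmentations` equal chunks along `axis`, and this rank holds chunk number `axis_rank_offset`. A minimal pure-Python sketch of that arithmetic (hypothetical helper, not mcore code):

```python
def global_slice(axis, axis_rank_offset, axis_fragmentations, global_dim):
    """Return (axis, start, stop): the slice of the global tensor along `axis`
    held by a rank with the given offset tuple, assuming equal chunks."""
    chunk = global_dim // axis_fragmentations
    return (axis, axis_rank_offset * chunk, (axis_rank_offset + 1) * chunk)

# (0, rank, world_size) with rank=2, world_size=4 and a global dim of 16:
print(global_slice(0, 2, 4, 16))  # → (0, 8, 12)
```

So `(0, rank, world_size)` in the snippet means: shard dimension 0 into `world_size` pieces, with this rank owning piece number `rank`.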
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_)
6. All ShardedTensors are saved.
7. `metadata.json` file with backend and version metadata is saved to the checkpoint directory.

Maybe put a link to the source code here to show where those steps happen.
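As an illustration only (not the actual mcore implementation — the predicates below are hypothetical stand-ins for isinstance checks against `LocalNonPersistentObject` and `ShardedBase`), the dispatch in these steps amounts to partitioning the state dict by type:

```python
def partition_state_dict(state_dict, is_sharded, is_local_nonpersistent):
    """Toy sketch: split a (nested) state dict into a sharded part and a
    "common" part, ignoring local non-persistent objects during save."""
    sharded, common = {}, {}
    for key, value in state_dict.items():
        if is_local_nonpersistent(value):
            continue  # ignored during save
        if isinstance(value, dict):
            s, c = partition_state_dict(value, is_sharded, is_local_nonpersistent)
            sharded[key], common[key] = s, c
        elif is_sharded(value):
            sharded[key] = value  # handled by the sharded strategy
        else:
            common[key] = value  # saved once, by the coordinator rank
    return sharded, common
```

This is only meant to show the routing logic; the real code also extracts ShardedObjects from the sharded part and applies factories first.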
The sharded state dict is processed in the following way:

1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored

`LocalNonPersistentObject` wasn't explained. What's this?
return dist_checkpointing.load(sharded_state_dict, ckpt_dir, fully_parallel_load_strategy)

The `dist_checkpointing` package provides default strategies for some sharded backends, so it's enough to specify a tuple `(backend, version)` as a saving strategy.

Explain what backends and versions are here?
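My understanding (worth verifying against the package): the backend names a serialization format and the version its layout revision, so `dist_checkpointing.save(sharded_state_dict, ckpt_dir, ('torch_dist', 1))` resolves to the registered default save strategy for that pair. The checkpoint's `metadata.json` then records something like the following (field names illustrative, check the actual file):

```json
{
  "sharded_backend": "torch_dist",
  "sharded_backend_version": 1,
  "common_backend": "torch",
  "common_backend_version": 1
}
```

On load, these fields let the library pick a matching default load strategy without the user specifying one.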
from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy, TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper

# Base strategy: save sharded tensors in the `torch_dist` backend format, version 1
base_save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1)

Add comments.
The `dist_checkpointing` package provides default strategies for some sharded backends, so it's enough to specify a tuple `(backend, version)` as a saving strategy.
Backends and versions are stored in a `metadata.json` file inside the checkpoint so that the loading strategy can be determined automatically (provided that there exists a default loading strategy for a given backend and version).

For "sharded" strategies, currently the backends supported by default are based on `torch.distributed.checkpoint` format (`torch_dist` backend) and Zarr format (`zarr` backend).

Add a bit more to explain the difference?
Note: in order to reuse model ShardedTensors to create optimizer ShardedTensors, the model **ShardedTensors must wrap model parameters**, not just tensors
(obtaining a state dict with model parameters can be achieved by passing `keep_vars=True` to the model `state_dict` function).
Otherwise the correspondence between model ShardedTensors and optimizer states is impossible to recreate.
This is the reason for introducing ShardedTensorFactories - we have to register the original model parameter as `ShardedTensorFactory.data` and apply any subsequent transformations as a factory function in order to make sure that the same transformation can be applied to the optimizer states.

Show an example of the source code in mcore, if there is any.
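Not actual mcore source (the real pattern lives in the factory helpers of the `dist_checkpointing` package); a toy sketch of why the factory must receive the raw parameter, with hypothetical names:

```python
def make_transpose_factory(key, param):
    """Toy stand-in for a ShardedTensorFactory: `param` is the ORIGINAL model
    parameter; the transformation (here a transpose) lives in build_fn so the
    SAME function can later be re-applied to optimizer states mapped onto
    this parameter."""
    def build_fn(tensor):
        # Applied at save time, both to `param` and to each optimizer state
        return [list(row) for row in zip(*tensor)]  # transpose a nested list

    def merge_fn(tensor):
        # Inverse transformation, applied after loading
        return [list(row) for row in zip(*tensor)]

    return {"key": key, "data": param, "build_fn": build_fn, "merge_fn": merge_fn}

factory = make_transpose_factory("weight", [[1, 2, 3], [4, 5, 6]])
saved = factory["build_fn"](factory["data"])  # what lands in the checkpoint
restored = factory["merge_fn"](saved)         # round-trips to the original
```

The point of the sketch: because the transformation is a reusable function attached to the original parameter (`data`), it can be applied uniformly to that parameter and to every optimizer state tied to it.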
Extra flattening comes with an efficiency challenge during checkpoint resharding.
Since flattening is applied after the global tensor is sharded into the grid of local chunks, loading after resharding requires accessing non-contiguous data fragments.
An example solution for that is implemented in the `dist_checkpointing/strategies/resharding.py` module and involves saving the flattened tensor with a different global shape than the original one.

Please use the GitHub path.
* - 3
  - [5, 9]
* - 5
  - [10, 11]

Why does DP affect the local shards?
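A possible answer, as I understand the Distributed Optimizer flow (toy sketch, hypothetical helper): DP sharding is applied to the *flattened* local chunk, so each DP rank owns a contiguous range of the flat buffer, which in turn maps back to fragments of the global tensor:

```python
def flat_ranges(local_numel, dp_size):
    """Split a flattened local chunk of `local_numel` elements evenly across
    DP ranks; returns the [start, stop) range each DP rank owns within the
    flat buffer (assumes even divisibility, as in the examples above)."""
    chunk = local_numel // dp_size
    return [(r * chunk, (r + 1) * chunk) for r in range(dp_size)]

# e.g. a 12-element flattened chunk split across DP=4:
print(flat_ranges(12, 4))
```

So changing DP changes where the flat buffer is cut, which is why the local shards differ.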
------------

Model parallel training requires parallelism-aware checkpointing.
Megatron-Core provides a checkpointing library capable of handling all types of parallelisms used in LLMs training.
Megatron Core provides a checkpointing library capable of handling all types of parallelisms used in LLM training.

Model parallel training requires parallelism-aware checkpointing.
Megatron-Core provides a checkpointing library capable of handling all types of parallelisms used in LLMs training.
Although distributed checkpointing library is targeted at Megatron-Core model, it can be used with other models as well, provided an appropriate integration.
Although the distributed checkpointing library is targeted for the Megatron Core model, it can also be used with other models, as long as proper integration is implemented.
Although distributed checkpointing library is targeted at Megatron-Core model, it can be used with other models as well, provided an appropriate integration.

The library provides two main entrypoints: `dist_checkpointing.save` and `dist_checkpointing.load` which are meant to replace the `torch.save` and `torch.load` in the regular checkpointing flow.
Apart from that it provides mechanism to define different types of local tensors placement in the global checkpoint.
Apart from that, it provides a mechanism to define the different types of local tensors placement in the global checkpoint.
Apart from that it provides mechanism to define different types of local tensors placement in the global checkpoint.
Basic sharding
Basic Sharding
# For some distributed checkpoint backends this is actually what happens underneath.

Supported entities
Supported Entities
-------------
It's the primary use case of distributed checkpointing - tensors sharding.
Allows to define how PyTorch tensors are sharded across the workload.
See `Tensors transformations`_ section for more details on ShardedTensors.

See the `Tensors transformations`_ section for more details on ShardedTensors.
This class allows to defer tensors transformations until the actual saving.
A factory can expand a tensor into an arbitrary sub state dict (including all supported entities listed above).
The need for such deferral will be explained in the `Tensors transformations`_ section.
The ShardedTensorFactory class defers tensors transformations until they are actually saved.
This is a simple wrapper that allows to express the fact that the object wrapped with this class should end up in the final loaded state dict during loading.
During saving such objects are ignored.

Arbitrary object
Arbitrary Object

Entrypoints
Entry Points

Entrypoints
===========
There are several useful user entrypoints for checkpoint saving and loading.
There are several useful user entry points for checkpoint saving and loading.
Requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
The sharded state dict is processed in the following way:

1. The ShardedTensorFactories are applied
- The ShardedTensorFactories are applied.
The sharded state dict is processed in the following way:

1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
- The LocalNonPersistentObject is extracted from the sharded state dict and ignored.

1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
3. ShardedBase objects are extracted
- The ShardedBase objects are extracted.
1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
3. ShardedBase objects are extracted
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
- All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_).
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
3. ShardedBase objects are extracted
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_)
All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_).
Requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.
In practice, the same sharded state dict can usually be used for both saving and loading (the sharded state dict for loading will just contain tensors with uninitialized data).

The sharded state dict provided as an input is processed in the following way:
When the sharded state dict is provided as input, it is processed in the following way:

The sharded state dict provided as an input is processed in the following way:

1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
- The "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict.
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_)
6. All ShardedTensors are saved.
7. `metadata.json` file with backend and version metadata is saved to the checkpoint directory.

- The `metadata.json` file with backend and version metadata is saved to the checkpoint directory.
The sharded state dict provided as an input is processed in the following way:

1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
2. The ShardedTensorFactories from the input sharded state dict are applied
- The ShardedTensorFactories from the input sharded state dict are applied.

1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
2. The ShardedTensorFactories from the input sharded state dict are applied
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
- The LocalNonPersistentObject is extracted from the input sharded state dict, unwrapped and added to the resulting state dict.
1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
2. The ShardedTensorFactories from the input sharded state dict are applied
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
4. ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict
- The ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict.
2. The ShardedTensorFactories from the input sharded state dict are applied
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
4. ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict
5. ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict
- The ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict.
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
4. ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict
5. ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict
6. Factory merges are applied (see `Optimizers`_ for explanation)
- Factory merges are applied (see `Optimizers`_ for explanation).
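The load steps above can be sketched as follows (illustration only; the lambdas and the toy wrapper class stand in for the real strategies and for `dist_checkpointing`'s actual `LocalNonPersistentObject`):

```python
class LocalNonPersistentObject:
    """Toy stand-in: wraps a value that is kept locally, never read from disk."""
    def __init__(self, obj):
        self.obj = obj
    def unwrap(self):
        return self.obj

def load_sketch(sharded_state_dict, load_common, load_sharded):
    """Toy version of the load flow: start from the "common" base, then
    overlay entries resolved from the sharded part of the checkpoint."""
    result = dict(load_common())          # step 1: common state dict is the base
    for key, value in sharded_state_dict.items():
        if isinstance(value, LocalNonPersistentObject):
            result[key] = value.unwrap()  # step 3: passes through unchanged
        else:
            result[key] = load_sharded(key, value)  # steps 4-5: read from ckpt
    return result
```

Factories (steps 2 and 6) are omitted here; they wrap this flow by expanding entries before loading and merging them back afterwards.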
* - 5
  - [10, 11]

The same tensor after sharding by TP=6, flattening and sharding by DP=1:
After sharding by TP=6 and flattening and sharding by DP=1, the resulting local shards are as follows:
- [5, 11]

Arbitrary transformations
Arbitrary Transformations
For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor, because the flattening and slicing should happen in the transposed dimension.

Application integration
Application Integration
I reviewed the intro.rst file and provided edits here: #9503
Basic sharding
--------------

The main way to define relationship of a plain local PyTorch tensor to tensors on other ranks is by wrapping it in a `ShardedTensor` class.
The main way to define the relationship of a plain, local PyTorch tensor to tensors on other ranks is by wrapping it in a `ShardedTensor` class.
A factory can expand a tensor into an arbitrary sub state dict (including all supported entities listed above).
The need for such deferral will be explained in the `Tensors transformations`_ section.

LocalNonpersitentObject
LocalNonpersistentObject

LocalNonpersitentObject
-----------------------
This is a simple wrapper that allows to express the fact that the object wrapped with this class should end up in the final loaded state dict during loading.
LocalNonpersistentObject is a simple wrapper indicating that the object wrapped with this class should end up in the final loaded state dict during loading.

Arbitrary object
----------------
All objects different than dicts, lists and the instances of the classes listed above are treated as "common" objects.
All objects different than dicts, lists, and the instances of the classes listed above are treated as "common" objects.
----------------
All objects different than dicts, lists and the instances of the classes listed above are treated as "common" objects.

During saving, all such objects in the sharded state dict passed to `dist_checkpointing.save` are assumed to be duplicated across ranks and therefore saved only by a single coordinator rank (rank 0).
During saving, all such objects in the sharded state dict passed to `dist_checkpointing.save` are assumed to be duplicated across ranks. Therefore, they are saved only by a single coordinator rank (rank 0).

dist_checkpointing.load
-----------------------
The main entrypoint for checkpoint loading.
The dist_checkpointing.load function is the main entry point for checkpoint loading.
dist_checkpointing.load
-----------------------
The main entrypoint for checkpoint loading.
Requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.
It requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.

Optimizers
==========
This module gives helper tools to the user to simplify constructing ShardedTensors for optimizer states.
The Optimizers module provides helper tools to the user to simplify constructing ShardedTensors for optimizer states.
Optimizers
==========
This module gives helper tools to the user to simplify constructing ShardedTensors for optimizer states.
The ShardedTensors that define local to sharded tensors mapping for model parameters should be reused for optimizer states to avoid code duplication.
The ShardedTensors that define local-to-sharded tensors mapping for model parameters should be reused for optimizer states to avoid code duplication.
This should support most optimizer cases, but some of them might require custom sharded state dict creation.
A good example is a Distributed Optimizer which flattens the parameters - see `Tensors transformations`_ section for more details.

Note: in order to reuse model ShardedTensors to create optimizer ShardedTensors, the model **ShardedTensors must wrap model parameters**, not just tensors
Note: In order to reuse model ShardedTensors to create optimizer ShardedTensors, the model ShardedTensors must wrap model parameters, not just tensors.

Shape mismatch
--------------
The `allow_shape_mismatch` flag allows to relax the requirement of matching global tensor shapes during loading.
The `allow_shape_mismatch` flag relaxes the requirement of matching global tensor shapes during loading.

Flattening
----------
The `flattened_range` attribute allows to declare the fact that `ShardedTensor.data` is actually a slice of a flattened model parameter.
The flattened_range attribute declares that ShardedTensor.data represents a slice of a flattened model parameter.
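As a toy illustration (not the real class), `flattened_range` can be read as "this rank's data is its chunk of the global tensor, flattened, then sliced by the given range":

```python
def local_flat_slice(local_chunk, flattened_range):
    """Toy sketch: `local_chunk` is this rank's (unflattened) chunk of the
    global tensor as a nested list; `flattened_range` is a (start, stop)
    range into its row-major flattening."""
    flat = [x for row in local_chunk for x in row]
    start, stop = flattened_range
    return flat[start:stop]

# chunk [[0, 1, 2], [3, 4, 5]] flattened is [0..5]; this rank keeps elements 2..5
print(local_flat_slice([[0, 1, 2], [3, 4, 5]], (2, 5)))  # → [2, 3, 4]
```

This is what makes the Distributed Optimizer case work: the flat slice boundaries need not align with any dimension of the original tensor.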

Arbitrary transformations
-------------------------
The way to apply arbitrary transformations to the tensors during saving and loading is with ShardedTensorFactory, which allows to define such transformations as a function that can be reapplied to any ShardedTensor (in particular, a ShardedTensor representing optimizer states).
The way to apply arbitrary transformations to the tensors during saving and loading is with ShardedTensorFactory. It defines such transformations as a function that can be reapplied to any ShardedTensor (in particular, a ShardedTensor representing optimizer states).
In order to apply such transformations both to model and optimizer parameters in a consistent manner, it's necessary to encode them as factory functions (with the original model parameter as the `data` input so that the optimizer params can be properly mapped to model ShardedTensors).

Note that implementing some transformations might be challenging or impossible while supporting flattening for a Distributed Optimizer case.
For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor, because the flattening and slicing should happen in the transposed dimension.
For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor. This is because the flattening and slicing should happen in the transposed dimension.
The only thing required from the application side is preparing a sharded state dict with ShardedTensors, ShardedObjects, etc. (representing the sharding of the data employed by the application)
and using the `dist_checkpointing.save` and `dist_checkpointing.load` entrypoints as replacements for `torch.save` and `torch.load`.

In Megatron-Core the sharded state dict preparation is already implemented in a `sharded_state_dict` method added to all Megatron-Core models and modules, which allows to create sharded state dicts in a composable way.
In Megatron Core, the sharded state dictionary preparation is already implemented in a sharded_state_dict method which creates the sharded state dicts in a composable way.
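A schematic of that composable pattern (hypothetical module class, not actual mcore code; key strings stand in for ShardedTensors): each module contributes its own entries and recurses into submodules with a longer key prefix:

```python
class ToyModule:
    """Toy stand-in for a Megatron-Core module with a sharded_state_dict method."""
    def __init__(self, name, children=(), params=()):
        self.name, self.children, self.params = name, children, params

    def sharded_state_dict(self, prefix=''):
        # Own parameters first, then recurse into submodules with an extended
        # prefix so every entry gets a unique, hierarchy-derived key.
        sd = {f'{prefix}{p}': f'ShardedTensor({prefix}{p})' for p in self.params}
        for child in self.children:
            sd.update(child.sharded_state_dict(prefix=f'{prefix}{child.name}.'))
        return sd

model = ToyModule('model',
                  children=[ToyModule('layer0', params=['weight'])],
                  params=['embed'])
```

The design point: because each module only knows its own sharding, composing modules composes their sharded state dicts automatically.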
}
dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)

During load the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).
During load, the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).
dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)

During load the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).
The main difference wrt. `torch.load` is that the user has to provide the definition of the sharded state dict that needs to be loaded.
The main difference with `torch.load` is that the user has to provide the definition of the sharded state dict that needs to be loaded.

ShardedBase
-----------
Base class for expressing any kind of sharding.
ShardedBase is the base class for expressing any kind of sharding.
ShardedBase
-----------
Base class for expressing any kind of sharding.
Each sharded entity must be uniquely identified by its `key`, carry some `data` to be saved or loaded and define `replica_id` which helps identify data redundancy.
Each sharded entity must be uniquely identified by its `key`, carry some `data` to be saved or loaded, and define `replica_id`, which helps identify data redundancy.

ShardedTensor
-------------
It's the primary use case of distributed checkpointing - tensors sharding.
ShardedTensor is the primary use case for distributed checkpointing - tensor sharding.
ShardedTensor
-------------
It's the primary use case of distributed checkpointing - tensors sharding.
Allows to define how PyTorch tensors are sharded across the workload.
It defines how PyTorch tensors are distributed across the workload.

dist_checkpointing.save
-----------------------
The only entrypoint for checkpoint saving.
The dist_checkpointing.save function is the only entry point for checkpoint saving.
dist_checkpointing.save
-----------------------
The only entrypoint for checkpoint saving.
Requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
It requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
@@ -0,0 +1,392 @@
Distributed checkpoints
Distributed Checkpoints
Distributed checkpoints
=======================

This guide provides details about the distributed checkpoints format from Megatron-Core.
This guide provides details about the distributed checkpoints format from Megatron Core.