
Distributed checkpointing user guide #9494

Open
mikolajblaz wants to merge 5 commits into base: yuya/add_checkpoints_section

Conversation

mikolajblaz

What does this PR do?

Add a one line overview of what this PR aims to accomplish.

Collection: [Note which collection this PR will affect]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

Signed-off-by: Mikołaj Błaż <[email protected]>

# Distributed checkpoint save
sharded_state_dict = {
    # (0, rank, world_size) = (sharding axis, this rank's offset along that axis, number of fragments along that axis)
    'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten, (0, rank, world_size))

can you add a comment explaining (0, rank, world_size)?
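
For illustration, a minimal sketch of how the rank-offset tuples compose when a tensor is sharded along more than one axis (the rank and group-size variables are placeholders, not from this PR):

# Each tuple is (axis, offset of this rank's chunk along that axis, number of fragments along that axis).
sharded_weight = dist_checkpointing.ShardedTensor.from_rank_offsets(
    'decoder.weight',              # unique key in the checkpoint
    local_chunk,                   # this rank's local tensor
    (0, row_rank, row_groups),     # axis 0 is split into row_groups pieces; this rank holds piece row_rank
    (1, col_rank, col_groups),     # axis 1 is split into col_groups pieces; this rank holds piece col_rank
)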

4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_)
6. All ShardedTensors are saved.
7. `metadata.json` file with backend and version metadata is saved to the checkpoint directory.

maybe put a link to source code here to show where those steps happen
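
For orientation, a hedged sketch of the kind of mixed sharded state dict these steps operate on (keys, shapes and the rank/world_size variables are made up; ShardedObject is assumed to be exposed at the package level and to take (key, data, global_shape, global_offset)):

sharded_state_dict = {
    # step 6: a sharded tensor
    'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_weight, (0, rank, world_size)),
    # step 5: a sharded (non-tensor) object, one per rank
    'rng_state': dist_checkpointing.ShardedObject('rng_state', rng_state, (world_size,), (rank,)),
    # step 4: a plain "common" object, assumed identical on all ranks
    'iteration': iteration,
}
dist_checkpointing.save(sharded_state_dict, ckpt_dir)  # step 7: also writes metadata.json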

The sharded state dict is processed in the following way:

1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored

LocalNonPersistentObject wasn't explained. What's this?
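
For context, a hedged sketch of how such a wrapper is used (in Megatron-Core the class is spelled LocalNonpersitentObject and is assumed to be importable from the package root; the key and value here are made up):

# Wrap a value that should not be written to the checkpoint but should still
# appear in the state dict returned by load().
sharded_state_dict = {
    'data_loader_position': dist_checkpointing.LocalNonpersitentObject(local_position),
}
# During save this entry is ignored; during load it is unwrapped and placed
# back into the resulting state dict unchanged.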

return dist_checkpointing.load(sharded_state_dict, ckpt_dir, fully_parallel_load_strategy)


The `dist_checkpointing` package provides default strategies for some sharded backends, so it's enough to specify a tuple `(backend, version)` as a saving strategy.

  • explain what backends and versions are here?
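
A hedged illustration of the tuple form (the checkpoint path is a placeholder):

# Instead of constructing a strategy object explicitly, pass (backend, version)
# and let dist_checkpointing resolve the default save strategy for that backend.
dist_checkpointing.save(sharded_state_dict, '/ckpts/iter_0000100', ('torch_dist', 1))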

from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy, TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper

# Base strategy: save ShardedTensors with the `torch_dist` backend, format version 1.
base_save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1)

add comments
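
For example, a hedged sketch of how the base strategy might be wrapped (assuming the wrappers accept the base strategy followed by an optional parallelization process group; data_parallel_group is a placeholder):

# Distribute the saving/loading work across the ranks of the given process group.
fully_parallel_save_strategy = FullyParallelSaveStrategyWrapper(base_save_strategy, data_parallel_group)
fully_parallel_load_strategy = FullyParallelLoadStrategyWrapper(TorchDistLoadShardedStrategy(), data_parallel_group)

dist_checkpointing.save(sharded_state_dict, ckpt_dir, fully_parallel_save_strategy)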

The `dist_checkpointing` package provides default strategies for some sharded backends, so it's enough to specify a tuple `(backend, version)` as a saving strategy.
Backends and versions are stored in a `metadata.json` file inside the checkpoint so that the loading strategy can be determined automatically (provided that there exists a default loading strategy for a given backend and version).

For "sharded" strategies, currently the backends supported by default are based on `torch.distributed.checkpoint` format (`torch_dist` backend) and Zarr format (`zarr` backend).

add a bit more to explain the difference?

Note: in order to reuse model ShardedTensors to create optimizer ShardedTensors, the model **ShardedTensors must wrap model parameters**, not just tensors
(obtaining a state dict with model parameters can be achieved by passing `keep_vars=True` to the model `state_dict` function).
Otherwise the correspondence between model ShardedTensors and optimizer states is impossible to recreate.
This is the reason for introducing ShardedTensorFactories - we have to register the original model parameter as `ShardedTensorFactory.data` and apply any subsequent transformations as a factory function in order to make sure that the same transformation can be applied to the optimizer states.

show an example source code in mcore if there's any
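
A hedged sketch of the `keep_vars=True` point (sharding every entry along axis 0 is for illustration only; real code shards each parameter according to its own parallel layout):

# Build model ShardedTensors around the actual nn.Parameter objects
# (keep_vars=True) so that optimizer states can later be mapped back to them.
param_state_dict = model.state_dict(keep_vars=True)   # values are parameters, not detached tensors
model_sharded_state_dict = {
    key: dist_checkpointing.ShardedTensor.from_rank_offsets(key, param, (0, tp_rank, tp_size))
    for key, param in param_state_dict.items()
}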


Extra flattening comes with an efficiency challenge during checkpoint resharding.
Since flattening is applied after the global tensor is sharded into the grid of local chunks, loading after resharding requires accessing non-contiguous data fragments.
An example solution for that is implemented in the `dist_checkpointing/strategies/resharding.py` module and involves saving the flattened tensor with a different global shape than the original one.

please use github path

* - 3
- [5, 9]
* - 5
- [10, 11]

why does DP affect the local shards?
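
A hedged sketch of why that happens: with a Distributed Optimizer, each rank's flattened TP chunk is additionally split across DP ranks, so the DP size determines which flattened range each rank owns (all numbers below are made up for illustration):

# A 12-element axis sharded by TP=2, then each local chunk flattened and
# split across DP=3 ranks (as the Distributed Optimizer does).
global_len, tp, dp = 12, 2, 3
tp_chunk = global_len // tp            # 6 elements per TP rank
dp_slice = tp_chunk // dp              # 2 flattened elements per DP rank
for tp_rank in range(tp):
    for dp_rank in range(dp):
        start = dp_rank * dp_slice     # offset inside this TP rank's flattened chunk
        print(tp_rank, dp_rank, (start, start + dp_slice))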

------------

Model parallel training requires parallelism-aware checkpointing.
Megatron-Core provides a checkpointing library capable of handling all types of parallelisms used in LLMs training.

Megatron Core provides a checkpointing library capable of handling all types of parallelisms used in LLM training.


Model parallel training requires parallelism-aware checkpointing.
Megatron-Core provides a checkpointing library capable of handling all types of parallelisms used in LLMs training.
Although distributed checkpointing library is targeted at Megatron-Core model, it can be used with other models as well, provided an appropriate integration.
@jgerh jgerh Jun 26, 2024

Although the distributed checkpointing library is targeted for the Megatron Core model, it can also be used with other models, as long as proper integration is implemented.

Although distributed checkpointing library is targeted at Megatron-Core model, it can be used with other models as well, provided an appropriate integration.

The library provides two main entrypoints: `dist_checkpointing.save` and `dist_checkpointing.load` which are meant to replace the `torch.save` and `torch.load` in the regular checkpointing flow.
Apart from that it provides mechanism to define different types of local tensors placement in the global checkpoint.
@jgerh jgerh Jun 26, 2024

Apart from that, it provides a mechanism to define the different types of local tensors placement in the global checkpoint.

Apart from that it provides mechanism to define different types of local tensors placement in the global checkpoint.
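
For reference, a hedged sketch of the two entrypoints used as drop-in replacements (the checkpoint directory and the construction of sharded_state_dict are placeholders):

# Regular flow: torch.save(model.state_dict(), ckpt_path) / torch.load(ckpt_path)
# Distributed flow: a sharded state dict and a checkpoint *directory* instead of a file.
dist_checkpointing.save(sharded_state_dict, '/ckpts/iter_0000100')

# Loading fills in the ShardedTensors declared in the provided sharded state dict
# and returns a regular state dict (assuming keys match the model's state_dict).
state_dict = dist_checkpointing.load(sharded_state_dict, '/ckpts/iter_0000100')
model.load_state_dict(state_dict)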


Basic sharding

Basic Sharding

# For some distributed checkpoint backends this is actually what happens underneath.


Supported entities

Supported Entities

-------------
It's the primary use case of distributed checkpointing - tensors sharding.
Allows to define how PyTorch tensors are sharded across the workload.
See `Tensors transformations`_ section for more details on ShardedTensors.

See the Tensors transformations_ section for more details on ShardedTensors.

Comment on lines +160 to +162
This class allows to defer tensors transformations until the actual saving.
A factory can expand a tensor into an arbitrary sub state dict (including all supported entities listed above).
The need for such deferral will be explained in the `Tensors transformations`_ section.
@jgerh jgerh Jun 26, 2024

The ShardedTensorFactory class defers tensors transformations until they are actually saved.

This is a simple wrapper that allows to express the fact that the object wrapped with this class should end up in the final loaded state dict during loading.
During saving such objects are ignored.

Arbitrary object

Arbitrary Object




Entrypoints

Entry Points


Entrypoints
===========
There are several useful user entrypoints for checkpoint saving and loading.
@jgerh jgerh Jun 26, 2024

There are several useful user entry points for checkpoint saving and loading.

Requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
The sharded state dict is processed in the following way:

1. The ShardedTensorFactories are applied

  1. The ShardedTensorFactories are applied.

The sharded state dict is processed in the following way:

1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
@jgerh jgerh Jun 26, 2024

  1. The LocalNonPersistentObject is extracted from the sharded state dict and ignored.


1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
3. ShardedBase objects are extracted
@jgerh jgerh Jun 26, 2024

  1. The ShardedBase objects are extracted.

1. The ShardedTensorFactories are applied
2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
3. ShardedBase objects are extracted
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)

  1. All other objects are treated as "common" and saved according to a sharded strategy (see Save and load strategies_).

2. LocalNonPersistentObject are extracted from the sharded state dict and ignored
3. ShardedBase objects are extracted
4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_)

All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see Save and load strategies_).

Requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.
In practice, the same sharded state dict can be usually used for both saving and loading (the sharded state dict for loading will just contain tensors with uninitialized data).

The sharded state dict provided as an input is processed in the following way:
@jgerh jgerh Jun 26, 2024

When the sharded state dict is provided as input, it is processed in the following way:


The sharded state dict provided as an input is processed in the following way:

1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
@jgerh jgerh Jun 26, 2024

  1. The "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict.

4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_)
5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_)
6. All ShardedTensors are saved.
7. `metadata.json` file with backend and version metadata is saved to the checkpoint directory.

  1. The metadata.json file with backend and version metadata is saved to the checkpoint directory.

The sharded state dict provided as an input is processed in the following way:

1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
2. The ShardedTensorFactories from the input sharded state dict are applied

  1. The ShardedTensorFactories from the input sharded state dict are applied.


1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
2. The ShardedTensorFactories from the input sharded state dict are applied
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict

  1. The LocalNonPersistentObject is extracted from the input sharded state dict, unwrapped and added to the resulting state dict.

1. "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict
2. The ShardedTensorFactories from the input sharded state dict are applied
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
4. ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict

  1. The ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict.

2. The ShardedTensorFactories from the input sharded state dict are applied
3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
4. ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict
5. ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict

  1. The ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict.

3. LocalNonPersistentObject are extracted from the input sharded state dict, unwrapped and added to the resulting state dict
4. ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict
5. ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict
6. Factory merges are applied (see `Optimizers`_ for explanation)

  1. Factory merges are applied (see Optimizers_ for explanation).

* - 5
- [10, 11]

The same tensor after sharding by TP=6, flattening and sharding by DP=1:

After sharding by TP=6 and flattening and sharding by DP=1, the resulting local shards are as follows:

- [5, 11]


Arbitrary transformations

Arbitrary Transformations

For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor, because the flattening and slicing should happen in the transposed dimension.


Application integration

Application Integration


I reviewed the intro.rst file and provided edits here: #9503

Basic sharding
--------------

The main way to define relationship of a plain local PyTorch tensor to tensors on other ranks is by wrapping it in a `ShardedTensor` class.
@jgerh jgerh Jun 27, 2024

The main way to define the relationship of a plain, local PyTorch tensor to tensors on other ranks is by wrapping it in a ShardedTensor class.

A factory can expand a tensor into an arbitrary sub state dict (including all supported entities listed above).
The need for such deferral will be explained in the `Tensors transformations`_ section.

LocalNonpersitentObject

LocalNonpersistentObject


LocalNonpersitentObject
-----------------------
This is a simple wrapper that allows to express the fact that the object wrapped with this class should end up in the final loaded state dict during loading.
@jgerh jgerh Jun 27, 2024

LocalNonpersistentObject is a simple wrapper indicating that the object wrapped with this class should end up in the final loaded state dict during loading.


Arbitrary object
----------------
All objects different than dicts, lists and the instances of the classes listed above are treated as "common" objects.

All objects different than dicts, lists, and the instances of the classes listed above are treated as "common" objects.

----------------
All objects different than dicts, lists and the instances of the classes listed above are treated as "common" objects.

During saving, all such objects in the sharded state dict passed to `dist_checkpointing.save` are assumed to be duplicated across ranks and therefore saved only by a single coordinator rank (rank 0).

During saving, all such objects in the sharded state dict passed to dist_checkpointing.save are assumed to be duplicated across ranks. Therefore, they are saved only by a single coordinator rank (rank 0).


dist_checkpointing.load
-----------------------
The main entrypoint for checkpoint loading.

The dist_checkpointing.load function is the main entry point for checkpoint loading.

dist_checkpointing.load
-----------------------
The main entrypoint for checkpoint loading.
Requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.

It requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies.


Optimizers
==========
This module gives helper tools to the user to simplify constructing ShardedTensors for optimizer states.

The Optimizers module provides helper tools to the user to simplify constructing ShardedTensors for optimizer states.

Optimizers
==========
This module gives helper tools to the user to simplify constructing ShardedTensors for optimizer states.
The ShardedTensors that define local to sharded tensors mapping for model parameters should be reused for optimizer states to avoid code duplication.

The ShardedTensors that define local-to-sharded tensors mapping for model parameters should be reused for optimizer states to avoid code duplication.

This should support most optimizer cases, but some of them might require custom sharded state dict creation.
A good example is a Distributed Optimizer which flattens the parameters - see `Tensors transformations`_ section for more details.

Note: in order to reuse model ShardedTensors to create optimizer ShardedTensors, the model **ShardedTensors must wrap model parameters**, not just tensors

Note: In order to reuse model ShardedTensors to create optimizer ShardedTensors, the model ShardedTensors must wrap model parameters, not just tensors


Shape mismatch
--------------
The `allow_shape_mismatch` flag allows to relax the requirement of matching global tensor shapes during loading.

The allow_shape_mismatch flag relaxes the requirement of matching global tensor shapes during loading.
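
A hedged sketch of where the flag is typically set (padded vocabulary embeddings are the usual example; the key and variables are placeholders):

# Allow the global shape stored in the checkpoint to differ from the runtime one,
# e.g. when the embedding table is padded to a divisible vocabulary size.
sharded_embedding = dist_checkpointing.ShardedTensor.from_rank_offsets(
    'embedding.word_embeddings.weight',
    local_embedding,
    (0, tp_rank, tp_size),
    allow_shape_mismatch=True,
)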


Flattening
----------
The `flattened_range` attribute allows to declare the fact that `ShardedTensor.data` is actually a slice of a flattened model parameter.

The flattened_range attribute declares that ShardedTensor.data represents a slice of a flattened model parameter.


Arbitrary transformations
-------------------------
The way to apply arbitrary transformations to the tensors during saving and loading is with ShardedTensorFactory, which allows to define such transformations as a function that can be reapplied to any ShardedTensor (in particular, a ShardedTensor representing optimizer states).
@jgerh jgerh Jun 27, 2024

The way to apply arbitrary transformations to the tensors during saving and loading is with ShardedTensorFactory. It defines such transformations as a function that can be reapplied to any ShardedTensor (in particular, a ShardedTensor representing optimizer states).

In order to apply such transformation both to model and optimizer parameters in a consistent manner, it's necessary to encode them as factory functions (with original model parameter as the `data` input so that the optimizer params can be properly mapped to model ShardedTensors).

Note that implementing some transformations might be challenging or impossible while supporting flattening for a Distributed Optimizer case.
For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor, because the flattening and slicing should happen in the transposed dimension.
@jgerh jgerh Jun 27, 2024

For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor. This is because the flattening and slicing should happen in the transposed dimension.

The only thing required from the application side is preparing a sharded state dict with ShardedTensors, ShardedObjects, etc. (representing the sharding of the data employed by the application)
and using the `dist_checkpointing.save` and `dist_checkpointing.load` entrypoints as replacements for `torch.save` and `torch.load`.

In Megatron-Core the sharded state dict preparation is already implemented in a `sharded_state_dict` method added to all Megatron-Core models and modules, which allows to create sharded state dicts in a composable way.

In Megatron Core, the sharded state dictionary preparation is already implemented in a sharded_state_dict method which creates the sharded state dicts in a composable way.
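
A hedged sketch of that integration pattern (the gpt_model name and the checkpoint path are placeholders; sharded_state_dict() is the method the text refers to):

# Build the mapping with the model's own sharded_state_dict(), then use the
# dist_checkpointing entrypoints in place of torch.save / torch.load.
sharded_state_dict = gpt_model.sharded_state_dict()
dist_checkpointing.save(sharded_state_dict, '/ckpts/iter_0000100')

# On a (possibly resized) job, rebuild the sharded state dict and load into it:
loaded = dist_checkpointing.load(gpt_model.sharded_state_dict(), '/ckpts/iter_0000100')
gpt_model.load_state_dict(loaded)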

}
dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)

During load the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).

During load, the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).

dist_checkpointing.save(sharded_state_dict, dist_ckpt_root)

During load the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks).
The main difference wrt. `torch.load` is that the user has to provide the definition of the sharded state dict that needs to be loaded.

The main difference with torch.load is that the user has to provide the definition of the sharded state dict that needs to be loaded.


ShardedBase
-----------
Base class for expressing any kind of sharding.

ShardedBase is the base class for expressing any kind of sharding.

ShardedBase
-----------
Base class for expressing any kind of sharding.
Each sharded entity must be uniquely identified by its `key`, carry some `data` to be saved or loaded and define `replica_id` which helps identify data redundancy.

Each sharded entity must be uniquely identified by its key, carry some data to be saved or loaded, and define replica_id which helps identify data redundancy.


ShardedTensor
-------------
It's the primary use case of distributed checkpointing - tensors sharding.

ShardedTensor is the primary use case for distributed checkpointing - tensor sharding.

ShardedTensor
-------------
It's the primary use case of distributed checkpointing - tensors sharding.
Allows to define how PyTorch tensors are sharded across the workload.

It defines how PyTorch tensors are distributed across the workload.


dist_checkpointing.save
-----------------------
The only entrypoint for checkpoint saving.

The dist_checkpointing.save function is the only entry point for checkpoint saving.

dist_checkpointing.save
-----------------------
The only entrypoint for checkpoint saving.
Requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).

It requires providing a sharded state dict to save and saving strategies for handling different entities (see Save and load strategies_ for detailed explanation).

@@ -0,0 +1,392 @@
Distributed checkpoints

Distributed Checkpoints

Distributed checkpoints
=======================

This guide provides details about the distributed checkpoints format from Megatron-Core.

This guide provides details about the distributed checkpoints format from Megatron Core.

@mikolajblaz mikolajblaz changed the title from "Mblaz/docs dist ckpt" to "Distributed checkpointing user guide" on Jun 28, 2024