Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support serialization of compound module #2082

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

TroyGarden
Copy link
Contributor

Summary:

details

  • serialization schedule
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
  • parent_fqn's children
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None

Differential Revision: D58221182

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 6, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 6, 2024
Summary:

# details
* serialization schedule
```
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
```
* `parent_fqn`'s children
```
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
```
```
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
```
```
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None
```

Differential Revision: D58221182
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

Summary:
Pull Request resolved: pytorch#2082

# context
* to support compound module serialization such as fpEBC and fpPEA, we modified the serializer interface a little
* the basic idea is that:
a. if a compound module A consists of child module B and C, such that module B and C have their own serializer available.
b. for serialization of module A, we can just capture it relation from B and C, and used it when deserialization
c. specifically, during the deserialization, A's children will be passed in as a dict of (child_fqn => child_module)

# design doc
[**TorchRec Composable Serializer Design**](https://docs.google.com/document/d/1WUtmzdcqZmwLd4Do8g1fQRjChnRw0ZimyUrCNhxa4nA/edit#heading=h.ezrtdguw0lwq)

# note
* in order to apply this approach, it requires that module A's construction **takes in it's children as objects**
* **DO NOT** create A's children in the A's construction

# details
* serialization schedule
```
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
```
* `parent_fqn`'s children
```
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
```
```
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
```
```
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None
```

Differential Revision: D58221182
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants