[Bug]: optimizer state not saved #3444

Open
chelseagzr opened this issue Apr 19, 2024 · 2 comments

chelseagzr commented Apr 19, 2024

Describe the bug

Thank you for developing and maintaining this invaluable module!

We would like to save the optimizer state at the end of each epoch.
The save_optimizer_state parameter of ModelTrainer.fine_tune appears to be designed for this purpose.
However, the optimizer state is not saved even when we set save_optimizer_state=True.

Thank you!

To Reproduce

%pip install scipy==1.10.1 datasets transformers torch==2.0 flair==0.13.1 

import torch
import flair
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# 1. get the corpus
corpus: Corpus = TREC_6()

# 2. what label do we want to predict?
label_type = 'question_class'

# 3. create the label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)

# 4. initialize transformer document embeddings (many models are available)
document_embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=True)

# 5. create the text classifier
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)

# 6. initialize trainer
trainer = ModelTrainer(classifier, corpus)

# 7. run training with fine-tuning
trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                  learning_rate=5.0e-5,
                  mini_batch_size=4,
                  max_epochs=10,
                  save_optimizer_state=True,
                  save_model_each_k_epochs=1
)

checkpoint = torch.load('resources/taggers/question-classification-with-transformer/model_epoch_1.pt', map_location=flair.device)
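
To make the missing state visible, the loaded object can be searched for any key that mentions the optimizer. The helper below is only a sketch: `find_keys` is a hypothetical name, and it assumes the object returned by `torch.load` is a (possibly nested) dictionary. It makes no assumption about the actual layout or key names of a Flair checkpoint, so a stand-in dictionary is used for the demonstration.

```python
from collections.abc import Mapping


def find_keys(obj, needle, prefix=""):
    """Return dotted paths of all mapping keys whose name contains `needle`.

    Hypothetical helper: walks nested dictionaries only and ignores
    everything else (tensors, lists, model objects).
    """
    hits = []
    if isinstance(obj, Mapping):
        for key, value in obj.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            if needle in str(key).lower():
                hits.append(path)
            hits.extend(find_keys(value, needle, path))
    return hits


# Stand-in for the object returned by torch.load(...) in the script above.
# The key names here are made up for illustration.
fake_checkpoint = {
    "state_dict": {"weight": 0, "bias": 0},
    "extra": {"learning_rate": 5.0e-5},
}

# An empty list means no key mentioning the optimizer was found.
print(find_keys(fake_checkpoint, "optimizer"))
```

Calling `find_keys(checkpoint, "optimizer")` on the real checkpoint would list every nested key containing "optimizer", if the checkpoint is a dictionary.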

Expected behavior

When save_optimizer_state=True, each saved checkpoint should contain the optimizer's state_dict.

Logs and Stack traces

No response

Screenshots

No response

Additional Context

No response

Environment

Versions:

Flair: 0.13.1
Pytorch: 2.0.0+cu117
Transformers: 4.40.0
GPU: True

@chelseagzr chelseagzr added the bug Something isn't working label Apr 19, 2024
@helpmefindaname helpmefindaname self-assigned this May 3, 2024
@helpmefindaname (Collaborator)

Hi @chelseagzr

Thank you for reporting. Saving and loading the optimizer state is currently not possible, and that flag should have been removed when the trainer was reworked.

I discussed reintroducing this with @alanakbik, but changed so that it stores not only the optimizer state but also the lr-scheduler and plugin states.
I am thinking of a state that can be loaded via the trainer (e.g. trainer = ModelTrainer.load_checkpoint("checkpoint.pt")) to restore the training states, while still allowing Classifier.load("checkpoint.pt") to load the model without the training states.
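
The idea of one checkpoint file with two loading paths can be sketched in plain PyTorch. This is not Flair's API and not the planned implementation: `save_checkpoint`, `load_model_only`, `load_training_state`, and the dictionary keys are hypothetical names chosen for illustration, and plugin states are left out.

```python
import torch
from torch import nn


def save_checkpoint(path, model, optimizer, scheduler, epoch):
    """Store model weights and training states together in one file."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch,
        },
        path,
    )


def load_model_only(path, model):
    """Restore only the weights; the training states in the file are ignored."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


def load_training_state(path, model, optimizer, scheduler):
    """Restore weights, optimizer and scheduler; return the saved epoch."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"]
```

The same file serves both cases: resuming training calls `load_training_state`, while inference only needs `load_model_only`.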

@chelseagzr (Author)

Thank you for the timely response! It would be great if saving the state of a trainer could be enabled!
