
🆎 Language model training & inference for text generation with transformers using pytorch


d1pankarmedhi/attention-transformers


Attention Transformers


Language model training and inference scripts, written in PyTorch, that use a self-attention mechanism for text generation.
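As a rough illustration of the self-attention mechanism, the sketch below implements a single head of masked (causal) scaled dot-product attention in PyTorch, the building block that a decoder-only model stacks into multi-head attention layers. The class name `CausalSelfAttentionHead` and its arguments are hypothetical and may not match the repository's code. The lower-triangular mask is what prevents each character from attending to characters that come after it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttentionHead(nn.Module):
    """One head of masked (causal) scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, n_embd: int, head_size: int, block_size: int):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask: position t may attend only to positions 0..t.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, T, _ = x.shape                          # x: (batch, time, n_embd)
        k = self.key(x)                            # (B, T, head_size)
        q = self.query(x)                          # (B, T, head_size)
        v = self.value(x)                          # (B, T, head_size)
        # Similarity of every query with every key, scaled by 1/sqrt(head_size).
        scores = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5   # (B, T, T)
        # Hide future positions so their softmax weight becomes exactly zero.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)        # each row sums to 1
        return weights @ v                         # (B, T, head_size)


if __name__ == "__main__":
    torch.manual_seed(0)
    head = CausalSelfAttentionHead(n_embd=256, head_size=64, block_size=128)
    x = torch.randn(2, 10, 256)                    # 2 sequences of 10 embedded characters
    print(head(x).shape)                           # torch.Size([2, 10, 64])
```

Running several such heads in parallel with `head_size = n_embd // n_head` and concatenating their outputs gives multi-head attention; with n_embd = 256, as in the sample run, four heads of size 64 would be one possible split.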

Text Generation with Decoder Model

This project demonstrates how a transformer decoder can generate text one character at a time. The model is trained on a dataset of news headlines.

The final model should be able to generate random news headlines.

🏗️ Training Data

# data.txt
Natalee Holloway's suspected killer, Joran van der Sloot, admits to crime, says mother
Could Venezuela’s diaspora hold the key to its opposition primary race?
Joran van der Sloot expected to plead guilty to federal charges at Wednesday hearing
Natalee Holloway's mother tells her daughter's killer in court he has caused 'indescribable pain and harm' to her family
WeWork's inevitable retreat is here
Poland introduces border controls with EU neighbor
FBI details how van der Sloot's confession in Natalee Holloway's death came together
He is Ecuador’s youngest president-elect. What lies ahead for Daniel Noboa?
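The training file above is plain text, so a character-level model needs a mapping between characters and integer ids. The following is a minimal sketch, assuming the vocabulary is simply every distinct character in data.txt. The helper names (`build_char_vocab`, `encode`, `decode`, `train_val_split`) and the 90/10 train/validation split are illustrative assumptions, not taken from the repository.

```python
import torch


def build_char_vocab(text: str):
    """Map every distinct character in the corpus to an integer id and back."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


def encode(s: str, stoi: dict) -> list:
    return [stoi[c] for c in s]


def decode(ids, itos: dict) -> str:
    # int() lets this accept plain ints as well as 0-d tensors.
    return "".join(itos[int(i)] for i in ids)


def train_val_split(text: str, stoi: dict, train_frac: float = 0.9):
    """Encode the whole corpus and split it into train/validation tensors (assumed 90/10)."""
    data = torch.tensor(encode(text, stoi), dtype=torch.long)
    n = int(train_frac * len(data))
    return data[:n], data[n:]


if __name__ == "__main__":
    # In the project the corpus would be read from data.txt, e.g.
    # text = open("data.txt", encoding="utf-8").read()
    text = "WeWork's inevitable retreat is here\nPoland introduces border controls with EU neighbor\n"
    stoi, itos = build_char_vocab(text)
    train_data, val_data = train_val_split(text, stoi)
    print(len(stoi), len(train_data), len(val_data))
    print(decode(encode("Poland", stoi), itos))  # Poland
```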

🏃‍♂️ Getting started

Set up the environment

# Clone the repository and cd into it.
git clone https://github.com/d1pankarmedhi/attention-transformers.git
cd attention-transformers

# Create and activate a virtual environment 
python -m venv .venv
source .venv/bin/activate # linux
.venv\Scripts\activate # windows

# Install the dependencies
pip install -r requirements.txt

🚂 Training

Open the params.py file and adjust the hyperparameters as required, in particular epochs and batch_size.

Make sure the data.txt file is present in the root directory of the repository.
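For orientation, a params.py along these lines might look like the sketch below. Only epochs, batch_size and n_embd come from the sample training run reported in this README; every other name and value is a hypothetical placeholder, so check the actual file before relying on any of them.

```python
# params.py (hypothetical layout; only epochs, batch_size and n_embd
# are taken from the sample training run in this README)
import torch

epochs = 5000          # training iterations (random batches), as in the sample run
batch_size = 32        # sequences per batch, as in the sample run
n_embd = 256           # embedding size (d_model), as in the sample run

# Illustrative placeholders, NOT the repository's actual settings.
block_size = 128       # context length in characters
n_head = 4             # attention heads per block; must divide n_embd
n_layer = 4            # number of transformer blocks
dropout = 0.2
learning_rate = 3e-4
eval_interval = 500    # how often to estimate train/val loss
device = "cuda" if torch.cuda.is_available() else "cpu"

assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
```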

# Run train.py to start training
python train.py
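Training a character-level model of this kind usually means repeatedly sampling random windows of the encoded text and minimizing the next-character cross-entropy. The sketch below shows such a loop under stated assumptions: `get_batch` and `train` are hypothetical helpers, AdamW with a 3e-4 learning rate is an assumed optimizer setting, and the demo uses a tiny embedding-plus-linear stand-in model rather than the repository's transformer. In this style of loop each "epoch" is one random batch, not a full pass over the data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_batch(data: torch.Tensor, block_size: int, batch_size: int):
    """Sample random windows; y is x shifted one character to the right."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


def train(model: nn.Module, data: torch.Tensor, epochs: int, batch_size: int,
          block_size: int, lr: float = 3e-4) -> float:
    """Minimal next-character training loop; returns the last batch loss."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    loss = torch.tensor(float("nan"))
    for _ in range(epochs):                          # one random batch per "epoch"
        xb, yb = get_batch(data, block_size, batch_size)
        logits = model(xb)                           # (B, T, vocab_size)
        B, T, V = logits.shape
        loss = F.cross_entropy(logits.view(B * T, V), yb.view(B * T))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 10
    data = torch.randint(vocab_size, (1000,))        # stand-in for the encoded data.txt
    # Stand-in model (embedding + linear head), NOT the repository's transformer.
    model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))
    print(train(model, data, epochs=100, batch_size=8, block_size=16))
```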

The trained model is saved to models/model.pt, from where it can be loaded for inference.
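Assuming the weights are stored as a `state_dict` (the README does not say whether the whole module or only its weights are saved), saving and reloading could look like the sketch below. `save_model` and `load_model` are hypothetical helpers; loading a `state_dict` requires first constructing the model with the same hyperparameters used for training.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_model(model: nn.Module, path: str = "models/model.pt") -> None:
    """Save only the weights, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(model.state_dict(), path)


def load_model(model: nn.Module, path: str = "models/model.pt", device: str = "cpu") -> nn.Module:
    """Load weights into an already-constructed model and switch it to eval mode."""
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model.to(device).eval()


if __name__ == "__main__":
    # Use a temporary directory so the demo does not overwrite a real checkpoint.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "models", "model.pt")
        src = nn.Linear(4, 2)
        save_model(src, path)
        dst = load_model(nn.Linear(4, 2), path)
        print(torch.equal(src.weight, dst.weight))  # True
```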

🧗‍♂️ Inference

# Run main.py to generate text
python main.py
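Text generation with a decoder is autoregressive: the model predicts a distribution over the next character, one id is sampled from it and appended, and the loop repeats. Below is a minimal sketch, assuming the model maps a (B, T) tensor of ids to (B, T, vocab_size) logits. `generate`, the `temperature` knob and the untrained stand-in model are illustrative and not necessarily what main.py does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def generate(model: nn.Module, idx: torch.Tensor, max_new_tokens: int,
             block_size: int, temperature: float = 1.0) -> torch.Tensor:
    """Extend idx (B, T) by sampling one character id at a time."""
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                     # keep at most block_size ids of context
        logits = model(idx_cond)[:, -1, :] / temperature    # (B, vocab_size) for the last position
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)   # (B, 1) sampled ids
        idx = torch.cat([idx, next_id], dim=1)
    return idx


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 10
    # Untrained stand-in model, NOT the repository's transformer.
    model = nn.Sequential(nn.Embedding(vocab_size, 16), nn.Linear(16, vocab_size))
    start = torch.zeros((1, 1), dtype=torch.long)           # e.g. the id of "\n"
    out = generate(model, start, max_new_tokens=20, block_size=8)
    print(out.shape)                                        # torch.Size([1, 21])
    # With a real character vocabulary, a hypothetical decode(out[0].tolist(), itos)
    # would turn these ids back into text.
```

Lower temperatures make sampling more conservative (closer to always picking the most likely character), while higher ones produce more varied but less coherent text.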

🗒️ Sample output

Training Parameters:

# Trained on a T4 GPU
epochs = 5000
batch_size = 32
n_embd or d_model = 256
parameters = 6.403185 M 

# loss
val_loss = 1.6990
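Two quick calculations help put these numbers in context. A parameter count like the one above is typically the sum of `numel()` over `model.parameters()`. If val_loss is the mean per-character cross-entropy in nats (the default for PyTorch's `F.cross_entropy`, assumed here), it converts to a perplexity of about 5.47, or about 2.45 bits per character. The helper names are hypothetical.

```python
import math

import torch.nn as nn


def count_parameters_millions(model: nn.Module) -> float:
    """Total number of parameters, in millions."""
    return sum(p.numel() for p in model.parameters()) / 1e6


def loss_to_perplexity_and_bpc(loss_nats: float):
    """Convert mean per-character cross-entropy (nats) to perplexity and bits per character."""
    return math.exp(loss_nats), loss_nats / math.log(2)


if __name__ == "__main__":
    # A single 256 -> 256 linear layer: 256*256 weights + 256 biases.
    print(count_parameters_millions(nn.Linear(256, 256)))   # 0.065792
    ppl, bpc = loss_to_perplexity_and_bpc(1.6990)           # val_loss from the run above
    print(f"perplexity={ppl:.2f}, bits/char={bpc:.2f}")     # perplexity=5.47, bits/char=2.45
```

A perplexity near 5.5 means that, on average, the model is about as uncertain as if it were choosing uniformly among five or six characters at each step.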

Generated Text:

Restrictions of volkay may be set to end his finds
Philippine prices earn to Gaza Student and talks with the world
Ball for braces in car car shopping as govt
Formula One to Americans Town Water Tadmil New Paper
India's ready former PM Sunhila 208 2023 ads Arid Lespeed on Humanity
Argentina votes Tesla Eras stripped on half a ban piece

Although the generated headlines are mostly nonsensical, it is notable that a model with about 6.4M parameters, trained for only 5000 iterations, already produces text whose structure and vocabulary resemble real news headlines.

Further hyperparameter tuning, longer training, or a larger model would likely lower the validation loss and improve the quality of the generated headlines.

🤝 Acknowledgments

I would like to express my gratitude to the following individuals and projects for inspiration.