[won't merge - v1 codebase] Bert #1543

Open · wants to merge 32 commits into base: master

Changes from all commits (32 commits)
fd8ac2e
Bert init commit
Zenglinxiao Jul 18, 2019
7601a88
support file
Zenglinxiao Jul 18, 2019
1c0498e
activation function
Zenglinxiao Jul 18, 2019
de3ca85
bert dataset
Zenglinxiao Jul 18, 2019
ede0250
add a new way of using bert
Zenglinxiao Jul 19, 2019
1dfa50a
merge some function
Zenglinxiao Jul 19, 2019
c1dd1f9
adapt BERT related module to ONMT habit
Zenglinxiao Jul 23, 2019
8c3436f
add downsteam task support
Zenglinxiao Jul 23, 2019
12a909a
update
Zenglinxiao Jul 25, 2019
ea14b13
update
Zenglinxiao Jul 26, 2019
3fae446
fix bug; add new feature
Zenglinxiao Aug 9, 2019
4b511e4
add prediction file
Zenglinxiao Aug 13, 2019
7f6a127
clean up code
Zenglinxiao Aug 14, 2019
892e0a0
tagging bug fix
Zenglinxiao Aug 20, 2019
ed0cf4d
clean code
Zenglinxiao Aug 26, 2019
dde55e6
Merge branch 'master' of https://github.com/OpenNMT/OpenNMT-py into bert
Zenglinxiao Aug 26, 2019
6c5ec3a
Fix flake8
Zenglinxiao Aug 26, 2019
ba8a358
solve PR check
Zenglinxiao Aug 26, 2019
08b1080
minor changes to make code simpler/more explicit
pltrdy Aug 26, 2019
b317ecf
Merge pull request #1 from pltrdy/bert
Zenglinxiao Aug 26, 2019
660e459
simplify code
Zenglinxiao Aug 27, 2019
e5b0355
fix import; clarify FAQ
Zenglinxiao Aug 27, 2019
1a676b2
fix build
Zenglinxiao Aug 27, 2019
4335a13
fix exception
Zenglinxiao Aug 27, 2019
f5aec9f
switch BertLayerNorm to offical LayerNorm, change BertAdam to AdamW w…
Zenglinxiao Sep 2, 2019
4938a93
fix bert valid step, remove unuse part in saver
Zenglinxiao Sep 4, 2019
b1658f5
add dynamic batchingwhen inference
Zenglinxiao Sep 12, 2019
9b1abd2
update classifier with confiance option
Zenglinxiao Nov 19, 2019
2e7e8d1
merge recent change on master
Zenglinxiao Nov 22, 2019
e352a94
rm tailing space
Zenglinxiao Nov 22, 2019
9d655fd
merge recent update from master
Zenglinxiao Nov 22, 2019
6c8e8e6
fix travis
Zenglinxiao Nov 22, 2019
1 change: 1 addition & 0 deletions .travis.yml
@@ -14,6 +14,7 @@ addons:
before_install:
# Install CPU version of PyTorch.
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install torch==1.2.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi
- pip install --upgrade setuptools
- pip install -r requirements.opt.txt
- python setup.py install
env:
146 changes: 146 additions & 0 deletions bert_ckp_convert.py
@@ -0,0 +1,146 @@
#!/usr/bin/env python
"""Convert HuggingFace BERT weights to OpenNMT-py BERT weights."""
from argparse import ArgumentParser
import torch
from onmt.encoders.bert import BertEncoder
from onmt.models.bert_generators import BertPreTrainingHeads
from onmt.modules.bert_embeddings import BertEmbeddings
from collections import OrderedDict
import re


def decrement(matched):
    value = int(matched.group(1))
    if value < 1:
        raise ValueError('Value Error when converting string')
    string = "bert.encoder.layer.{}.output.LayerNorm".format(value - 1)
    return string


def mapping_key(key, max_layers):
    if 'bert.embeddings' in key:
        key = key

    elif 'bert.encoder' in key:
        # convert layer_norm weights
        key = re.sub(r'bert.encoder.0.layer_norm\.(.*)',
                     r'bert.embeddings.LayerNorm.\1', key)
        key = re.sub(r'bert.encoder\.(\d+)\.layer_norm',
                     decrement, key)
        # convert attention weights
        key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_keys\.(.*)',
                     r'bert.encoder.layer.\1.attention.self.key.\2', key)
        key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_values\.(.*)',
                     r'bert.encoder.layer.\1.attention.self.value.\2', key)
        key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_query\.(.*)',
                     r'bert.encoder.layer.\1.attention.self.query.\2', key)
        key = re.sub(r'bert.encoder\.(\d+)\.self_attn.final_linear\.(.*)',
                     r'bert.encoder.layer.\1.attention.output.dense.\2', key)
        # convert feed forward weights
        key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.layer_norm\.(.*)',
                     r'bert.encoder.layer.\1.attention.output.LayerNorm.\2',
                     key)
        key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.w_1\.(.*)',
                     r'bert.encoder.layer.\1.intermediate.dense.\2', key)
        key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.w_2\.(.*)',
                     r'bert.encoder.layer.\1.output.dense.\2', key)

    elif 'bert.layer_norm' in key:
        key = re.sub(r'bert.layer_norm',
                     r'bert.encoder.layer.' + str(max_layers - 1) +
                     '.output.LayerNorm', key)
    elif 'bert.pooler' in key:
        key = key
    elif 'generator.next_sentence' in key:
        key = re.sub(r'generator.next_sentence.linear\.(.*)',
                     r'cls.seq_relationship.\1', key)
    elif 'generator.mask_lm' in key:
        key = re.sub(r'generator.mask_lm.bias',
                     r'cls.predictions.bias', key)
        key = re.sub(r'generator.mask_lm.decode.weight',
                     r'cls.predictions.decoder.weight', key)
        key = re.sub(r'generator.mask_lm.transform.dense\.(.*)',
                     r'cls.predictions.transform.dense.\1', key)
        key = re.sub(r'generator.mask_lm.transform.layer_norm\.(.*)',
                     r'cls.predictions.transform.LayerNorm.\1', key)
    else:
        raise KeyError("Unexpected keys! Please provide HuggingFace weights")
    return key


def convert_bert_weights(bert_model, weights, n_layers=12):
    bert_model_keys = bert_model.state_dict().keys()
    bert_weights = OrderedDict()
    generator_weights = OrderedDict()
    model_weights = {"bert": bert_weights,
                     "generator": generator_weights}
    hugface_keys = weights.keys()
    try:
        for key in bert_model_keys:
            hugface_key = mapping_key(key, n_layers)
            if hugface_key not in hugface_keys:
                if 'LayerNorm' in hugface_key:
                    # Fix LayerNorm of old huggingface checkpoints
                    hugface_key = re.sub(r'LayerNorm.weight',
                                         r'LayerNorm.gamma', hugface_key)
                    hugface_key = re.sub(r'LayerNorm.bias',
                                         r'LayerNorm.beta', hugface_key)
                    if hugface_key in hugface_keys:
                        print("[OLD weights file] gamma/beta are used in "
                              "naming BertLayerNorm. Mapping succeeded.")
                    else:
                        raise KeyError("Failed to fix LayerNorm %s, check file"
                                       % hugface_key)
                else:
                    raise KeyError("Mapped key %s not in weights file"
                                   % hugface_key)
            if 'generator' not in key:
                onmt_key = re.sub(r'bert\.(.*)', r'\1', key)
                model_weights['bert'][onmt_key] = weights[hugface_key]
            else:
                onmt_key = re.sub(r'generator\.(.*)', r'\1', key)
                model_weights['generator'][onmt_key] = weights[hugface_key]
    except KeyError:
        print("Conversion unsuccessful.")
        raise
    return model_weights


def main():
    parser = ArgumentParser()
    parser.add_argument("--layers", type=int, default=None, required=True)

    parser.add_argument("--bert_model_weights_file", "-i", type=str,
                        default=None, required=True, help="Path to the "
                        "HuggingFace BERT weights file downloaded from "
                        "https://github.com/huggingface/pytorch-transformers")

    parser.add_argument("--output_name", "-o", type=str,
                        default=None, required=True,
                        help="Path of the output OpenNMT-py BERT weights file")
    args = parser.parse_args()

    print("Model contains {} layers.".format(args.layers))

    print("Loading weights from {}.".format(args.bert_model_weights_file))

    bert_weights = torch.load(args.bert_model_weights_file)
    embeddings = BertEmbeddings(28996)  # vocab size does not affect conversion
    bert_encoder = BertEncoder(embeddings)
    generator = BertPreTrainingHeads(bert_encoder.d_model,
                                     embeddings.vocab_size)
    bertlm = torch.nn.Sequential(OrderedDict([
        ('bert', bert_encoder),
        ('generator', generator)]))
    model_weights = convert_bert_weights(bertlm, bert_weights, args.layers)

    ckp = {'model': model_weights['bert'],
           'generator': model_weights['generator']}

    outfile = args.output_name
    print("Saving converted weights to {}".format(outfile))
    torch.save(ckp, outfile)


if __name__ == '__main__':
    main()
144 changes: 144 additions & 0 deletions docs/source/FAQ.md
@@ -150,6 +150,150 @@ will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`,

**Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing.

## How do I use BERT?
BERT is a general-purpose "language understanding" model introduced by Google. It can be used for various downstream NLP tasks and adapted to a new task through transfer learning. Using BERT involves two stages: pre-training and fine-tuning. Since pre-training is very expensive, we do not recommend pre-training a BERT model from scratch; instead, load the weights of an existing pretrained model and fine-tune them. Currently we support the sentence(-pair) classification and token tagging downstream tasks.

### Use pretrained BERT weights
To use weights from an existing HuggingFace pretrained model, we provide a script that converts HuggingFace BERT weights into our format.

Usage:
```bash
python bert_ckp_convert.py --layers NUMBER_LAYER \
                           --bert_model_weights_file HUGGINGFACE_BERT_WEIGHTS \
                           --output_name OUTPUT_FILE
```
* Go to [modeling_bert.py](https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py) to check all available pretrained models.
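
To sanity-check a converted file, a minimal sketch like the one below can be used (the output file name is a hypothetical example; the top-level `model`/`generator` keys correspond to what `bert_ckp_convert.py` saves):

```python
import torch

# Inspect a converted checkpoint produced by bert_ckp_convert.py.
# "onmt_bert_base.pt" is a hypothetical output name.
ckp = torch.load("onmt_bert_base.pt", map_location="cpu")
print(list(ckp.keys()))  # expected: ['model', 'generator']

# Show a few converted parameter names and their shapes.
for name, tensor in list(ckp["model"].items())[:5]:
    print(name, tuple(tensor.shape))
```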

### Preprocess train/dev dataset
To generate train/dev data for BERT, use `preprocess_bert.py`: provide raw data in the expected format and choose a BERT tokenizer model (`-vm`) consistent with the pretrained model.
#### Classification
For classification datasets, we support input files in CSV or plain text format.

* For a CSV file, each line should contain one instance, with one or two sentence columns and one label column, as in the GLUE datasets; other CSV-format datasets should be compatible. A typical CSV file looks like:

| ID | SENTENCE_A | SENTENCE_B(Optional) | LABEL |
| -- | ------------------------ | ------------------------ | ------- |
| 0 | sentence a of instance 0 | sentence b of instance 0 | class 2 |
| 1 | sentence a of instance 1 | sentence b of instance 1 | class 1 |
| ...| ... | ... | ... |
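
Before preprocessing, a small sketch like the following can be used to peek at a few rows of such a GLUE-style TSV (the path is hypothetical, and the column positions simply mirror the table above: ID, sentence A, sentence B, label):

```python
import csv

# Quick look at the first rows of a GLUE-style TSV (hypothetical path).
with open("DATA_DIR/FILENAME.tsv", encoding="utf-8", newline="") as f:
    reader = csv.reader(f, delimiter="\t")
    next(reader)  # skip the header row, if the file has one (cf. --skip_head)
    for row in list(reader)[:3]:
        sentence_a, sentence_b, label = row[1], row[2], row[3]
        print(sentence_a, "|||", sentence_b, "->", label)
```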

Then call `preprocess_bert.py`, providing the input sentence columns and the label column:
```bash
python preprocess_bert.py --task classification --corpus_type {train, valid} \
                          --file_type csv [--delimiter '\t'] [--skip_head] \
                          --input_columns 1 2 --label_column 3 \
                          --data DATA_DIR/FILENAME.tsv \
                          --save_data dataset \
                          -vm bert-base-cased --max_seq_len 256 [--do_lower_case] \
                          [--sort_label_vocab] [--do_shuffle]
```
* For plain text format, we accept multiple files as input, where each file contains the instances of one specific class. Each line of a file contains one instance, composed of one sentence or of two sentences separated by ` ||| `. All input files should be arranged in the following way:
```
.
├── LABEL_A
│   └── FILE_WITH_INSTANCE_A
└── LABEL_B
└── FILE_WITH_INSTANCE_B
```
Then call `preprocess_bert.py` as follows to generate the training data (a small sketch of this file layout follows the command):
```bash
python preprocess_bert.py --task classification --corpus_type {train, valid} \
                          --file_type txt [--delimiter ' ||| '] \
                          --data DIR_BASE/LABEL_1/FILENAME1 ... DIR_BASE/LABEL_N/FILENAME2 \
                          --save_data dataset \
                          --vocab_model {bert-base-uncased,...} \
                          --max_seq_len 256 [--do_lower_case] \
                          [--sort_label_vocab] [--do_shuffle]
```
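
For illustration, here is a rough sketch (labels, sentences, and paths are all hypothetical) that writes sentence-pair instances in the layout described above: one instance per line, ` ||| ` between the two sentences, one directory per label:

```python
import os

# Hypothetical toy data: one list of (sentence_a, sentence_b) pairs per label.
examples = {
    "LABEL_A": [("first sentence of instance 0", "second sentence of instance 0")],
    "LABEL_B": [("first sentence of instance 1", "second sentence of instance 1")],
}

for label, pairs in examples.items():
    os.makedirs(label, exist_ok=True)
    with open(os.path.join(label, "instances.txt"), "w", encoding="utf-8") as f:
        for sent_a, sent_b in pairs:
            f.write(sent_a + " ||| " + sent_b + "\n")
```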
#### Tagging
For tagging datasets, we support input files in plain text format.

Each line of the input file should contain one token and its tag; fields are separated by a delimiter (a space by default), while sentences are separated by a blank line.

An example input file is given below (`Token X X Label`):
```
-DOCSTART- -X- O O

CRICKET NNP I-NP O
- : O O
LEICESTERSHIRE NNP I-NP I-ORG
TAKE NNP I-NP O
OVER IN I-PP O
AT NNP I-NP O
TOP NNP I-NP O
AFTER NNP I-NP O
INNINGS NNP I-NP O
VICTORY NN I-NP O
. . O O

LONDON NNP I-NP I-LOC
1996-08-30 CD I-NP O

```
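To make this layout concrete, here is a minimal parsing sketch (not part of the toolkit; the default column positions are assumptions and should be adjusted to match your `--input_columns`/`--label_column` choice):

```python
def read_tagging_file(path, token_col=0, tag_col=-1, delimiter=" "):
    """Read a delimiter-separated tagging file into (tokens, tags) sentences."""
    sentences, tokens, tags = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line.strip():  # a blank line ends the current sentence
                if tokens:
                    sentences.append((tokens, tags))
                    tokens, tags = [], []
                continue
            fields = line.split(delimiter)
            tokens.append(fields[token_col])
            tags.append(fields[tag_col])
    if tokens:  # flush the last sentence if there is no trailing blank line
        sentences.append((tokens, tags))
    return sentences
```
With the CoNLL-style example above, each sentence comes back as a pair of lists, taking the first column as the token and the last column as the tag.
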
Then call `preprocess_bert.py`, providing the token column and the label column, as follows to generate training data for the token tagging task:
```bash
python preprocess_bert.py --task tagging --corpus_type {train, valid} \
                          --file_type txt [--delimiter ' '] \
                          --input_columns 1 --label_column 3 \
                          --data DATA_DIR/FILENAME \
                          --save_data dataset \
                          --vocab_model {bert-base-uncased,...} \
                          --max_seq_len 256 [--do_lower_case] \
                          [--sort_label_vocab] [--do_shuffle]
```
#### Pretraining objective
Although it is not recommended, we also provide a script to generate a pretraining dataset, in case you want to fine-tune an existing pretrained model on masked language modeling and next sentence prediction.

The script expects a single file as input, consisting of untokenized text, with one sentence per line, and one blank line between documents.
A usage example is given below:
```bash
python pregenerate_bert_training_data.py --input_file INPUT_FILE \
                                         --output_dir OUTPUT_DIR \
                                         --output_name OUTPUT_FILE_PREFIX \
                                         --corpus_type {train, valid} \
                                         --vocab_model {bert-base-uncased,...} \
                                         [--do_lower_case] [--do_whole_word_mask] [--reduce_memory] \
                                         --epochs_to_generate 2 \
                                         --max_seq_len 128 \
                                         --short_seq_prob 0.1 --masked_lm_prob 0.15 \
                                         --max_predictions_per_seq 20 \
                                         [--save_json]
```
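
If your corpus is not yet in that shape, a rough sketch like the one below can produce the expected one-sentence-per-line, blank-line-between-documents layout (the naive period-based sentence split and the file name are assumptions; a real sentence splitter is preferable):

```python
# Hypothetical raw documents; in practice these would be read from disk.
documents = [
    "First document. It has two sentences.",
    "Second document of the corpus. It also has two sentences.",
]

with open("pretraining_corpus.txt", "w", encoding="utf-8") as f:
    for doc in documents:
        for sentence in (s.strip() for s in doc.split(".")):
            if sentence:
                f.write(sentence + ".\n")
        f.write("\n")  # a blank line separates documents
```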

### Training
Once the preprocessed data have been generated, you can load weights from a pretrained BERT and transfer them to a downstream task with a task-specific output head. This task-specific head is initialized with the method you choose if no such architecture exists in the weights file specified by `--train_from`. Among all available optimizers, we suggest `--optim bertadam`, as it is the method used to train BERT. When using the linear decay method, `warmup_steps` can be set to 1% of `train_steps`, as in the original paper (e.g. 40 warmup steps for 4000 training steps, as in the example below).

A usage example is given below:
```bash
python train.py --is_bert --task_type {pretraining, classification, tagging} \
                --data PREPROCESSED_DATAFILE \
                --train_from CONVERTED_CHECKPOINT.pt [--param_init 0.1] \
                --save_model MODEL_PREFIX --save_checkpoint_steps 1000 \
                [--world_size 2] [--gpu_ranks 0 1] \
                --word_vec_size 768 --rnn_size 768 \
                --layers 12 --heads 8 --transformer_ff 3072 \
                --activation gelu --dropout 0.1 --average_decay 0.0001 \
                --batch_size 8 [--accum_count 4] --optim bertadam [--max_grad_norm 0] \
                --learning_rate 2e-5 --learning_rate_decay 0.99 --decay_method linear \
                --train_steps 4000 --valid_steps 200 --warmup_steps 40 \
                [--report_every 10] [--seed 3435] \
                [--tensorboard] [--tensorboard_log_dir LOGDIR]
```

### Predicting
After training, you can use `predict.py` to generate predictions for a raw text file. Make sure to use the same BERT tokenizer model (`--vocab_model`) as was used for the training data.

For the classification task, the file to be predicted should contain one sentence (or sentence pair) per line, with ` ||| ` separating the two sentences.
For the tagging task, each line should be a tokenized sentence with tokens separated by spaces.

Usage:
```bash
python predict.py --task {classification, tagging} \
                  --model ONMT_BERT_CHECKPOINT.pt \
                  --vocab_model bert-base-uncased [--do_lower_case] \
                  --data DATA_2_PREDICT [--delimiter {' ||| ', ' '}] --max_seq_len 256 \
                  --output PREDICT.txt [--batch_size 8] [--gpu 1] [--seed 3435]
```
## Can I get word alignment while translating?

### Raw alignments from averaging Transformer attention heads
49 changes: 49 additions & 0 deletions docs/source/refs.bib
@@ -436,6 +436,55 @@ @article{DBLP:journals/corr/MartinsA16
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@article{DBLP:journals/corr/abs-1711-05101,
author = {Ilya Loshchilov and
Frank Hutter},
title = {Fixing Weight Decay Regularization in Adam},
journal = {CoRR},
volume = {abs/1711.05101},
year = {2017},
url = {http://arxiv.org/abs/1711.05101},
archivePrefix = {arXiv},
eprint = {1711.05101},
timestamp = {Mon, 13 Aug 2018 16:48:18 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1711-05101},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@article{DBLP:journals/corr/abs-1810-04805,
author = {Jacob Devlin and
Ming{-}Wei Chang and
Kenton Lee and
Kristina Toutanova},
title = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language
Understanding},
journal = {CoRR},
volume = {abs/1810.04805},
year = {2018},
url = {http://arxiv.org/abs/1810.04805},
archivePrefix = {arXiv},
eprint = {1810.04805},
timestamp = {Tue, 30 Oct 2018 20:39:56 +0100},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1810-04805},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@article{DBLP:journals/corr/HendrycksG16,
author = {Dan Hendrycks and
Kevin Gimpel},
title = {Bridging Nonlinearities and Stochastic Regularizers with Gaussian
Error Linear Units},
journal = {CoRR},
volume = {abs/1606.08415},
year = {2016},
url = {http://arxiv.org/abs/1606.08415},
archivePrefix = {arXiv},
eprint = {1606.08415},
timestamp = {Mon, 13 Aug 2018 16:46:20 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/HendrycksG16},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
2 changes: 1 addition & 1 deletion onmt/__init__.py
@@ -2,9 +2,9 @@
from __future__ import division, print_function

import onmt.inputters
import onmt.models
import onmt.encoders
import onmt.decoders
import onmt.models
import onmt.utils
import onmt.modules
from onmt.trainer import Trainer