[WIP] LAMB optimizer #1460

Open · wants to merge 3 commits into master
10 changes: 9 additions & 1 deletion onmt/opts.py
@@ -429,7 +429,7 @@ def train_opts(parser):
nargs="*", default=None,
help='Criteria to use for early stopping.')
group.add('--optim', '-optim', default='sgd',
choices=['sgd', 'adagrad', 'adadelta', 'adam',
choices=['sgd', 'adagrad', 'adadelta', 'adam', 'lamb',
'sparseadam', 'adafactor', 'fusedadam'],
help="Optimization method.")
group.add('--adagrad_accumulator_init', '-adagrad_accumulator_init',
@@ -466,6 +466,14 @@ def train_opts(parser):
'suggested a value of 0.98 for beta2, this parameter may '
'not work well for normal models / default '
'baselines.')
group.add('--lamb_beta1', '-lamb_beta1', type=float, default=0.9,
help="The beta1 parameter used by Lamb.")
group.add('--lamb_beta2', '-lamb_beta2', type=float, default=0.999,
help="The beta2 parameter used by Lamb.")
group.add('--lamb_eps', '-lamb_eps', type=float, default=1e-8,
help="The epsilon parameter used by Lamb.")
group.add('--lamb_wd', '-lamb_wd', type=float, default=0.0,
help="The weight decay parameter used by Lamb.")
group.add('--label_smoothing', '-label_smoothing', type=float, default=0.0,
help="Label smoothing value epsilon. "
"Probabilities of all non-true labels "
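For reference, a minimal sketch (not part of the diff) of the options namespace the flags added above produce, using hypothetical values; the lamb_* fields mirror the new options, and optim / learning_rate are the existing OpenNMT-py options that build_torch_optimizer() reads in the next file.

from argparse import Namespace

# Hypothetical values for illustration; lamb_* defaults mirror the help text above.
opt = Namespace(
    optim='lamb',
    learning_rate=2e-3,
    lamb_beta1=0.9,
    lamb_beta2=0.999,
    lamb_eps=1e-8,
    lamb_wd=0.0)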
116 changes: 116 additions & 0 deletions onmt/utils/optimizers.py
@@ -6,6 +6,7 @@
import functools
from copy import copy
from math import sqrt
import math

from onmt.utils.misc import fn_args

@@ -82,6 +83,13 @@ def build_torch_optimizer(model, opt):
params,
lr=opt.learning_rate,
betas=betas)
elif opt.optim == 'lamb':
optimizer = Lamb(
params,
lr=opt.learning_rate,
betas=(opt.lamb_beta1, opt.lamb_beta2),
eps=opt.lamb_eps,
weight_decay=opt.lamb_wd)
else:
raise ValueError('Invalid optimizer type: ' + opt.optim)

@@ -517,3 +525,111 @@ def step(self, closure=None):
p.data.add_(-group['weight_decay'] * lr_t, p.data)

return loss


class Lamb(torch.optim.Optimizer):
"""Implements Lamb algorithm.
Based on https://github.com/cybertronai/pytorch-lamb
which is itself based on `torch.optimizers.Adam`.
It has been proposed in `Reducing BERT Pre-Training Time
from 3 Days to 76 Minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or
dicts defining parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used
for computing running averages of gradient and
its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1,
which turns this into Adam. Useful for comparison purposes.
.. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
https://arxiv.org/abs/1904.00962
"""

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".
format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".
format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super(Lamb, self).__init__(params, defaults)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Lamb does not support sparse gradients,"
"consider SparseAdam instead.")

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)

# in the paper, exp_avg is m_t and exp_avg_sq is v_t
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)

# m = beta1 * m + (1 - beta1) * grad
exp_avg.mul_(beta1).add_(1 - beta1, grad)
# v = beta2 * v + (1 - beta2) * grad**2
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Fold the bias correction into the step size to avoid a broadcast.
step_size = group['lr'] * \
math.sqrt(bias_correction2) / bias_correction1

adam_step = exp_avg / denom
# The paper uses the L2 norm (sum of squares); using the mean (RMS)
# instead avoids overflow and leaves the ratio r1 / r2 unchanged.
r1 = p.data.pow(2).mean().sqrt()
r2 = adam_step.pow(2).mean().sqrt()
r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10)
state['r1'] = r1
state['r2'] = r2
state['r'] = r
if self.adam:
r = 1

p.data.add_(-step_size * r, adam_step)

return loss
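For completeness, a minimal usage sketch of the class above (not part of the PR), assuming a toy linear model and random data; after one step, each parameter's optimizer state also exposes the r1 / r2 norms and the clipped trust ratio r recorded in step().

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)  # toy model for illustration
optimizer = Lamb(model.parameters(), lr=1e-3,
                 betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)

x, target = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Inspect the trust-ratio statistics stored per parameter.
for p in model.parameters():
    st = optimizer.state[p]
    print(st['r1'].item(), st['r2'].item(), float(st['r']))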