LAMB implementation
This commit is contained in:
parent
c987545592
commit
0d07a23c04
|
@ -167,3 +167,96 @@ class AdamW(Optimizer):
|
|||
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
class Lamb(Optimizer):
|
||||
""" Implements the LAMB algorithm (Layer-wise Adaptive Moments optimizer for Batch training).
|
||||
|
||||
Adapted from the huggingface/transformers ADAM optimizer
|
||||
Inspired from the Google Research implementation available in ALBERT: https://github.com/google-research/google-research/blob/master/albert/lamb_optimizer.py
|
||||
Inspired from cybertronai's PyTorch LAMB implementation: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
|
||||
|
||||
|
||||
Parameters:
|
||||
lr (float): learning rate. Default 1e-3.
|
||||
betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
|
||||
eps (float): Adams epsilon. Default: 1e-6
|
||||
weight_decay (float): Weight decay. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
|
||||
if lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
correct_bias=correct_bias)
|
||||
super(Lamb, self).__init__(params, defaults)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('LAMB does not support sparse gradients.')
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
|
||||
|
||||
# Inspired from cybertronai's PyTorch LAMB implementation: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
|
||||
step_size = group['lr']
|
||||
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
|
||||
|
||||
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
|
||||
if group['weight_decay'] != 0:
|
||||
adam_step.add_(group['weight_decay'], p.data)
|
||||
|
||||
adam_norm = adam_step.pow(2).sum().sqrt()
|
||||
if weight_norm == 0 or adam_norm == 0:
|
||||
trust_ratio = 1
|
||||
else:
|
||||
trust_ratio = weight_norm / adam_norm
|
||||
|
||||
|
||||
state['weight_norm'] = weight_norm
|
||||
state['adam_norm'] = adam_norm
|
||||
state['trust_ratio'] = trust_ratio
|
||||
|
||||
p.data.add_(-step_size * trust_ratio, adam_step)
|
||||
return loss
|
||||
|
|
Loading…
Reference in New Issue