196 lines
8.6 KiB
Python
196 lines
8.6 KiB
Python
#! /usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
"""AMSGrad Implementation based on the paper: "On the Convergence of Adam and Beyond" (ICLR 2018)
|
|
Article Link: https://openreview.net/pdf?id=ryQu7f-RZ
|
|
Original Implementation by: https://github.com/taki0112/AMSGrad-Tensorflow
|
|
"""
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import (control_flow_ops, math_ops, resource_variable_ops, state_ops, variable_scope)
|
|
from tensorflow.python.training import optimizer
|
|
|
|
|
|
class AMSGrad(optimizer.Optimizer):
|
|
"""Implementation of the AMSGrad optimization algorithm.
|
|
|
|
See: `On the Convergence of Adam and Beyond - [Reddi et al., 2018] <https://openreview.net/pdf?id=ryQu7f-RZ>`__.
|
|
|
|
Parameters
|
|
----------
|
|
learning_rate: float
|
|
A Tensor or a floating point value. The learning rate.
|
|
beta1: float
|
|
A float value or a constant float tensor.
|
|
The exponential decay rate for the 1st moment estimates.
|
|
beta2: float
|
|
A float value or a constant float tensor.
|
|
The exponential decay rate for the 2nd moment estimates.
|
|
epsilon: float
|
|
A small constant for numerical stability.
|
|
This epsilon is "epsilon hat" in the Kingma and Ba paper
|
|
(in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
|
use_locking: bool
|
|
If True use locks for update operations.
|
|
name: str
|
|
Optional name for the operations created when applying gradients.
|
|
Defaults to "AMSGrad".
|
|
"""
|
|
|
|
def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8, use_locking=False, name="AMSGrad"):
|
|
"""Construct a new Adam optimizer."""
|
|
super(AMSGrad, self).__init__(use_locking, name)
|
|
self._lr = learning_rate
|
|
self._beta1 = beta1
|
|
self._beta2 = beta2
|
|
self._epsilon = epsilon
|
|
|
|
self._lr_t = None
|
|
self._beta1_t = None
|
|
self._beta2_t = None
|
|
self._epsilon_t = None
|
|
|
|
self._beta1_power = None
|
|
self._beta2_power = None
|
|
|
|
def _create_slots(self, var_list):
|
|
first_var = min(var_list, key=lambda x: x.name)
|
|
|
|
create_new = self._beta1_power is None
|
|
if not create_new and context.in_graph_mode():
|
|
create_new = (self._beta1_power.graph is not first_var.graph)
|
|
|
|
if create_new:
|
|
with ops.colocate_with(first_var):
|
|
self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False)
|
|
self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False)
|
|
# Create slots for the first and second moments.
|
|
for v in var_list:
|
|
self._zeros_slot(v, "m", self._name)
|
|
self._zeros_slot(v, "v", self._name)
|
|
self._zeros_slot(v, "vhat", self._name)
|
|
|
|
def _prepare(self):
|
|
self._lr_t = ops.convert_to_tensor(self._lr)
|
|
self._beta1_t = ops.convert_to_tensor(self._beta1)
|
|
self._beta2_t = ops.convert_to_tensor(self._beta2)
|
|
self._epsilon_t = ops.convert_to_tensor(self._epsilon)
|
|
|
|
def _apply_dense(self, grad, var):
|
|
beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
|
|
beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
|
|
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
|
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
|
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
|
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
|
|
|
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
|
|
|
# m_t = beta1 * m + (1 - beta1) * g_t
|
|
m = self.get_slot(var, "m")
|
|
m_scaled_g_values = grad * (1 - beta1_t)
|
|
m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)
|
|
|
|
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
|
v = self.get_slot(var, "v")
|
|
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
|
v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)
|
|
|
|
# amsgrad
|
|
vhat = self.get_slot(var, "vhat")
|
|
vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
|
|
v_sqrt = math_ops.sqrt(vhat_t)
|
|
|
|
var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
|
|
return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])
|
|
|
|
def _resource_apply_dense(self, grad, var):
|
|
var = var.handle
|
|
beta1_power = math_ops.cast(self._beta1_power, grad.dtype.base_dtype)
|
|
beta2_power = math_ops.cast(self._beta2_power, grad.dtype.base_dtype)
|
|
lr_t = math_ops.cast(self._lr_t, grad.dtype.base_dtype)
|
|
beta1_t = math_ops.cast(self._beta1_t, grad.dtype.base_dtype)
|
|
beta2_t = math_ops.cast(self._beta2_t, grad.dtype.base_dtype)
|
|
epsilon_t = math_ops.cast(self._epsilon_t, grad.dtype.base_dtype)
|
|
|
|
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
|
|
|
# m_t = beta1 * m + (1 - beta1) * g_t
|
|
m = self.get_slot(var, "m").handle
|
|
m_scaled_g_values = grad * (1 - beta1_t)
|
|
m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)
|
|
|
|
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
|
v = self.get_slot(var, "v").handle
|
|
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
|
v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)
|
|
|
|
# amsgrad
|
|
vhat = self.get_slot(var, "vhat").handle
|
|
vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
|
|
v_sqrt = math_ops.sqrt(vhat_t)
|
|
|
|
var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
|
|
return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])
|
|
|
|
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
|
|
beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
|
|
beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
|
|
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
|
|
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
|
|
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
|
|
epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
|
|
|
|
lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
|
|
|
|
# m_t = beta1 * m + (1 - beta1) * g_t
|
|
m = self.get_slot(var, "m")
|
|
m_scaled_g_values = grad * (1 - beta1_t)
|
|
m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
|
|
with ops.control_dependencies([m_t]):
|
|
m_t = scatter_add(m, indices, m_scaled_g_values)
|
|
|
|
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
|
|
v = self.get_slot(var, "v")
|
|
v_scaled_g_values = (grad * grad) * (1 - beta2_t)
|
|
v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
|
|
with ops.control_dependencies([v_t]):
|
|
v_t = scatter_add(v, indices, v_scaled_g_values)
|
|
|
|
# amsgrad
|
|
vhat = self.get_slot(var, "vhat")
|
|
vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
|
|
v_sqrt = math_ops.sqrt(vhat_t)
|
|
var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
|
|
return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])
|
|
|
|
def _apply_sparse(self, grad, var):
|
|
return self._apply_sparse_shared(
|
|
grad.values,
|
|
var,
|
|
grad.indices,
|
|
lambda x, i, v: state_ops.
|
|
scatter_add( # pylint: disable=g-long-lambda
|
|
x, i, v, use_locking=self._use_locking
|
|
)
|
|
)
|
|
|
|
def _resource_scatter_add(self, x, i, v):
|
|
with ops.control_dependencies([resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
|
|
return x.value()
|
|
|
|
def _resource_apply_sparse(self, grad, var, indices):
|
|
return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)
|
|
|
|
def _finish(self, update_ops, name_scope):
|
|
# Update the power accumulators.
|
|
with ops.control_dependencies(update_ops):
|
|
with ops.colocate_with(self._beta1_power):
|
|
update_beta1 = self._beta1_power.assign(
|
|
self._beta1_power * self._beta1_t, use_locking=self._use_locking
|
|
)
|
|
update_beta2 = self._beta2_power.assign(
|
|
self._beta2_power * self._beta2_t, use_locking=self._use_locking
|
|
)
|
|
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], name=name_scope)
|