schedule fix

This commit is contained in:
lukovnikov 2019-04-03 17:10:08 +02:00
parent b64cc63a77
commit 91a073f804
2 changed files with 10 additions and 13 deletions

View File

@ -38,11 +38,12 @@ class LRSchedule(object):
:param kw:
"""
super(LRSchedule, self).__init__(**kw)
self.warmup, self.t_total = warmup, t_total
if t_total <= 0:
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
warmup = max(warmup, 0)
self.warmup, self.t_total = warmup, t_total
self.warned_for_t_total_at_progress = -1
def get_lr(self, step, nowarn=False):
@ -51,6 +52,8 @@ class LRSchedule(object):
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
:return: learning rate multiplier for current update
"""
if self.t_total < 0:
return 1.
progress = step / self.t_total
ret = self.get_lr_(progress)
        # warning for exceeding t_total (only active with warmup_linear)
@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule):
self.cycles = cycles
def get_lr_(self, progress):
""" get learning rate multiplier """
if self.t_total <= 0:
return 1.
if progress < self.warmup:
return progress / self.warmup
else:
@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
assert(cycles >= 1.)
def get_lr_(self, progress):
if self.t_total <= 0:
return 1.
if progress < self.warmup:
return progress / self.warmup
else:
@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul
"""
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
assert(warmup * cycles < 1.)
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
warmup = warmup * cycles if warmup >= 0 else warmup
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
def get_lr_(self, progress):
if self.t_total <= 0.:
return 1.
progress = progress * self.cycles % 1.
if progress < self.warmup:
return progress / self.warmup
@ -174,7 +171,7 @@ class BertAdam(Optimizer):
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
t_total: total number of training steps for the learning
rate schedule, -1 means constant learning rate. Default: -1
rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
schedule: schedule to use for the warmup (see above).
Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
Default: 'warmup_linear'

View File

@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase):
class WarmupCosineWithRestartsTest(unittest.TestCase):
def test_it(self):
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5)
x = np.arange(0, 1000) / 1000
y = [m.get_lr_(xe) for xe in x]
m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
x = np.arange(0, 1000)
y = [m.get_lr(xe) for xe in x]
plt.plot(y)
plt.show(block=False)
y = np.asarray(y)