schedule fix
This commit is contained in:
parent
b64cc63a77
commit
91a073f804
|
@ -38,11 +38,12 @@ class LRSchedule(object):
|
|||
:param kw:
|
||||
"""
|
||||
super(LRSchedule, self).__init__(**kw)
|
||||
self.warmup, self.t_total = warmup, t_total
|
||||
if t_total <= 0:
|
||||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
|
||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||
warmup = max(warmup, 0)
|
||||
self.warmup, self.t_total = warmup, t_total
|
||||
self.warned_for_t_total_at_progress = -1
|
||||
|
||||
def get_lr(self, step, nowarn=False):
|
||||
|
@ -51,6 +52,8 @@ class LRSchedule(object):
|
|||
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
|
||||
:return: learning rate multiplier for current update
|
||||
"""
|
||||
if self.t_total < 0:
|
||||
return 1.
|
||||
progress = step / self.t_total
|
||||
ret = self.get_lr_(progress)
|
||||
# warning for exceeding t_total (only active with warmup_linear
|
||||
|
@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule):
|
|||
self.cycles = cycles
|
||||
|
||||
def get_lr_(self, progress):
|
||||
""" get learning rate multiplier """
|
||||
if self.t_total <= 0:
|
||||
return 1.
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
else:
|
||||
|
@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
|
|||
assert(cycles >= 1.)
|
||||
|
||||
def get_lr_(self, progress):
|
||||
if self.t_total <= 0:
|
||||
return 1.
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
else:
|
||||
|
@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedul
|
|||
"""
|
||||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
|
||||
assert(warmup * cycles < 1.)
|
||||
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup*cycles, t_total=t_total, cycles=cycles, **kw)
|
||||
warmup = warmup * cycles if warmup >= 0 else warmup
|
||||
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
|
||||
|
||||
def get_lr_(self, progress):
|
||||
if self.t_total <= 0.:
|
||||
return 1.
|
||||
progress = progress * self.cycles % 1.
|
||||
if progress < self.warmup:
|
||||
return progress / self.warmup
|
||||
|
@ -174,7 +171,7 @@ class BertAdam(Optimizer):
|
|||
lr: learning rate
|
||||
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
||||
t_total: total number of training steps for the learning
|
||||
rate schedule, -1 means constant learning rate. Default: -1
|
||||
rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
|
||||
schedule: schedule to use for the warmup (see above).
|
||||
Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
|
||||
Default: 'warmup_linear'
|
||||
|
|
|
@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase):
|
|||
|
||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||
def test_it(self):
|
||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5)
|
||||
x = np.arange(0, 1000) / 1000
|
||||
y = [m.get_lr_(xe) for xe in x]
|
||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
|
||||
x = np.arange(0, 1000)
|
||||
y = [m.get_lr(xe) for xe in x]
|
||||
plt.plot(y)
|
||||
plt.show(block=False)
|
||||
y = np.asarray(y)
|
||||
|
|
Loading…
Reference in New Issue