Skip to content

Commit

Permalink
Fix checkpoint-loading bug in LinearWarmup
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Oct 27, 2020
1 parent 6905608 commit 87fa3f8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 16 deletions.
50 changes: 36 additions & 14 deletions python/paddle/fluid/tests/unittests/test_lr_scheduler.py
Original file line number Diff line number Diff line change
def linear_warmup_lr(epoch_num,
                     warmup_steps,
                     start_lr,
                     end_lr,
                     verbose=False):
    """Reference (pure-Python) implementation of LinearWarmup for the test.

    During warmup (epoch_num < warmup_steps) the lr is linearly
    interpolated from ``start_lr`` to ``end_lr``. After warmup it mirrors
    the PiecewiseDecay([3, 6], [0.5, 0.2, 0.1]) schedule that the dygraph
    test wraps inside LinearWarmup (counting epochs from the end of
    warmup); the static-graph path uses a constant 0.5 wrapped lr.

    Args:
        epoch_num (int): current epoch index.
        warmup_steps (int): number of warmup epochs.
        start_lr (float): lr at epoch 0.
        end_lr (float): lr reached at the end of warmup.
        verbose (bool): unused here; kept to match the scheduler kwargs.

    Returns:
        float: the expected learning rate for ``epoch_num``.
    """
    # Epochs elapsed since warmup finished; negative while still warming up.
    tmp = epoch_num - warmup_steps
    if tmp < 0:
        # Linear interpolation between start_lr and end_lr.
        return start_lr + (end_lr - start_lr) * (float(epoch_num) /
                                                 float(warmup_steps))
    elif paddle.in_dynamic_mode():
        # Mirrors PiecewiseDecay(boundaries=[3, 6], values=[0.5, 0.2, 0.1])
        # stepped once per post-warmup epoch.
        if tmp < 3:
            return 0.5
        elif tmp < 6:
            return 0.2
        else:
            return 0.1
    else:
        # Static-graph path: wrapped learning rate is the constant 0.5.
        return 0.5


def multi_step_lr(epoch_num,
Expand Down Expand Up @@ -407,6 +415,9 @@ def _test_dygraph(self, python_func, paddle_api, kwarg, place):
paddle.disable_static(place)
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
if paddle_api.__name__ == "LinearWarmup":
kwarg['learning_rate'] = paddle.optimizer.lr.PiecewiseDecay(
[3, 6], [0.5, 0.2, 0.1])
scheduler = paddle_api(**kwarg)
adam = paddle.optimizer.Adam(
learning_rate=scheduler, parameters=linear.parameters())
Expand All @@ -420,12 +431,26 @@ def _test_dygraph(self, python_func, paddle_api, kwarg, place):
adam.clear_grad()
current_lr = adam.get_lr()
expected_lr = python_func(epoch, **kwarg)
if paddle_api.__name__ != "CosineAnnealingDecay":
self.assertEqual(current_lr, expected_lr)
scheduler.step()
else:
if paddle_api.__name__ == "CosineAnnealingDecay":
self.assertAlmostEqual(current_lr, expected_lr)
scheduler.step(epoch + 1)
elif paddle_api.__name__ == "LinearWarmup":
self.assertAlmostEqual(current_lr, expected_lr)
state_dict = adam.state_dict()
scheduler1 = paddle.optimizer.lr.LinearWarmup(**kwarg)
adam1 = paddle.optimizer.Adam(
learning_rate=scheduler1, parameters=linear.parameters())
adam1.set_state_dict(state_dict)
self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
self.assertEqual(scheduler.last_lr, scheduler1.last_lr)
self.assertEqual(scheduler.learning_rate.last_lr,
scheduler1.learning_rate.last_lr)
self.assertEqual(scheduler.learning_rate.last_epoch,
scheduler1.learning_rate.last_epoch)
scheduler.step()
else:
self.assertEqual(current_lr, expected_lr)
scheduler.step()

def test_scheduler(self):
with self.assertRaises(NotImplementedError):
Expand Down Expand Up @@ -464,8 +489,7 @@ def test_scheduler(self):
"decay_steps": 20,
"end_lr": 0,
"power": 1.0,
"cycle": False,
"verbose": True
"cycle": False
}), (polynomial_lr, paddle.optimizer.lr.PolynomialDecay, {
"learning_rate": 0.5,
"decay_steps": 20,
Expand All @@ -475,19 +499,17 @@ def test_scheduler(self):
"verbose": False
}), (linear_warmup_lr, paddle.optimizer.lr.LinearWarmup, {
'learning_rate': 0.5,
'warmup_steps': 20,
'warmup_steps': 10,
'start_lr': 0,
'end_lr': 0.5,
"verbose": True
'end_lr': 0.5
}), (exponential_lr, paddle.optimizer.lr.ExponentialDecay, {
"learning_rate": 0.5,
"gamma": 0.9,
"verbose": False
}), (multi_step_lr, paddle.optimizer.lr.MultiStepDecay, {
"learning_rate": 0.5,
"milestones": [3, 6, 9, 15, 20],
"gamma": 0.8,
"verbose": True
"gamma": 0.8
}), (step_lr, paddle.optimizer.lr.StepDecay, {
"learning_rate": 0.5,
"step_size": 2,
Expand All @@ -510,7 +532,7 @@ def test_scheduler(self):

for place in places:
paddle.enable_static()
#self._test_static(python_func, paddle_api, kwarg, place)
self._test_static(python_func, paddle_api, kwarg, place)
paddle.disable_static(place)
self._test_dygraph(python_func, paddle_api, kwarg, place)
paddle.enable_static()
Expand Down
23 changes: 21 additions & 2 deletions python/paddle/optimizer/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
last_epoch=last_epoch, verbose=verbose)

def get_lr(self):

for i in range(len(self.boundaries)):
if self.last_epoch < self.boundaries[i]:
return self.values[i]
Expand Down Expand Up @@ -750,14 +749,34 @@ def __init__(self,
end_lr, start_lr)
super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

def state_dict(self):
    """Return the LinearWarmup scheduler state as a ``dict``.

    Extends the base-class state with the wrapped scheduler's own state,
    stored under the ``"LinearWarmup_LR"`` key, whenever ``learning_rate``
    is itself an ``LRScheduler`` (rather than a plain float).
    """
    snapshot = super(LinearWarmup, self).state_dict()
    wrapped = self.learning_rate
    if isinstance(wrapped, LRScheduler):
        snapshot["LinearWarmup_LR"] = wrapped.state_dict()
    return snapshot

def set_state_dict(self, state_dict):
    """Load a state produced by :meth:`state_dict` into this scheduler.

    Restores the base-class state first, then — when ``learning_rate`` is
    an ``LRScheduler`` — restores the wrapped scheduler's state from the
    ``"LinearWarmup_LR"`` entry saved alongside it.
    """
    super(LinearWarmup, self).set_state_dict(state_dict)
    wrapped = self.learning_rate
    if isinstance(wrapped, LRScheduler):
        wrapped.set_state_dict(state_dict["LinearWarmup_LR"])

def get_lr(self):
if self.last_epoch < self.warmup_steps:
return (self.end_lr - self.start_lr) * float(
self.last_epoch) / float(self.warmup_steps) + self.start_lr
else:
if isinstance(self.learning_rate, LRScheduler):
lr_value = self.learning_rate()
self.learning_rate.step()
return self.learning_rate()
return lr_value

return self.learning_rate

Expand Down

1 comment on commit 87fa3f8

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉

Please sign in to comment.