test_optimizers.py
import math
import os

import pytest
import torch

import tests.models.utils as tutils
from pytorch_lightning import Trainer
from tests.models import (
    TestModelBase,
    LightTrainDataloader,
    LightTestOptimizerWithSchedulingMixin,
    LightTestMultipleOptimizersWithSchedulingMixin,
    LightTestOptimizersWithMixedSchedulingMixin
)


def test_optimizer_with_scheduling(tmpdir):
    """Verify that learning rate scheduling is working"""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTestOptimizerWithSchedulingMixin,
            LightTrainDataloader,
            TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # trainer options for a short run
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    init_lr = hparams.learning_rate
    adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups]

    assert len(trainer.lr_schedulers) == 1, \
        'lr scheduler not initialized properly, it has %i elements instead of 1' % len(trainer.lr_schedulers)

    assert all(a == adjusted_lr[0] for a in adjusted_lr), \
        'lr not equally adjusted for all param groups'
    adjusted_lr = adjusted_lr[0]

    assert init_lr * 0.1 == adjusted_lr, \
        'lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr)
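

# Illustrative sketch only (assumption, not the actual test mixin): the test
# above relies on LightTestOptimizerWithSchedulingMixin returning an optimizer
# together with a scheduler that decays the lr by a factor of 0.1 per epoch,
# roughly like this:
class _SketchOptimizerWithScheduling:
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # StepLR with step_size=1 multiplies the lr by gamma after every epoch,
        # so after the single training epoch the lr is init_lr * 0.1
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]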


def test_multi_optimizer_with_scheduling(tmpdir):
    """Verify that learning rate scheduling works with multiple optimizers"""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTestMultipleOptimizersWithSchedulingMixin,
            LightTrainDataloader,
            TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # trainer options for a short run
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    init_lr = hparams.learning_rate
    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups]

    assert len(trainer.lr_schedulers) == 2, \
        'lr schedulers not initialized properly, it has %i elements instead of 2' % len(trainer.lr_schedulers)

    assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \
        'lr not equally adjusted for all param groups for optimizer 1'
    adjusted_lr1 = adjusted_lr1[0]

    assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \
        'lr not equally adjusted for all param groups for optimizer 2'
    adjusted_lr2 = adjusted_lr2[0]

    assert init_lr * 0.1 == adjusted_lr1 and init_lr * 0.1 == adjusted_lr2, \
        'lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr1)
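

# Illustrative note (assumption, not the actual mixin): for the test above,
# LightTestMultipleOptimizersWithSchedulingMixin is expected to return one
# scheduler per optimizer, e.g.
#
#     return [optimizer1, optimizer2], [scheduler1, scheduler2]
#
# so the trainer registers two entries in trainer.lr_schedulers and both
# learning rates decay by 0.1 after the single epoch.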


def test_multi_optimizer_with_scheduling_stepping(tmpdir):
    """Verify that step-interval and epoch-interval schedulers can be mixed"""
    tutils.reset_seed()

    class CurrentTestModel(
            LightTestOptimizersWithMixedSchedulingMixin,
            LightTrainDataloader,
            TestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # trainer options for a short run
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model
    trainer = Trainer(**trainer_options)
    results = trainer.fit(model)

    init_lr = hparams.learning_rate
    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups]

    assert len(trainer.lr_schedulers) == 2, \
        'lr schedulers not initialized properly'

    assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \
        'lr not equally adjusted for all param groups for optimizer 1'
    adjusted_lr1 = adjusted_lr1[0]

    assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \
        'lr not equally adjusted for all param groups for optimizer 2'
    adjusted_lr2 = adjusted_lr2[0]

    # called every 3 steps, meaning for 1 epoch of 11 batches it is called 3 times
    assert init_lr * (0.1) ** 3 == adjusted_lr1, \
        'lr for optimizer 1 not adjusted correctly'
    # called once after the end of the epoch
    assert init_lr * 0.1 == adjusted_lr2, \
        'lr for optimizer 2 not adjusted correctly'
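

# Illustrative sketch only (assumption, not the actual test mixin): the mixed
# scheduling exercised above presumably pairs a per-step scheduler dict
# (interval 'step', frequency 3, hence three decays over the 11-batch epoch)
# with a plain per-epoch scheduler (a single decay), roughly like this:
class _SketchMixedScheduling:
    def configure_optimizers(self):
        optimizer1 = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        optimizer2 = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler1 = {
            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1, gamma=0.1),
            'interval': 'step',  # step after training batches instead of after epochs
            'frequency': 3,      # step every 3 batches -> 3 decays -> lr * 0.1 ** 3
        }
        # a bare scheduler object defaults to per-epoch stepping -> lr * 0.1
        scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1, gamma=0.1)
        return [optimizer1, optimizer2], [scheduler1, scheduler2]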