Skip to content

Commit 4dd0f91

Browse files
committed
Address comments
1 parent bec9064 commit 4dd0f91

File tree

6 files changed

+35
-55
lines changed

6 files changed

+35
-55
lines changed

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- pytest-cov
2222
- pytest-flake8
2323
- flake8
24+
- autopep8
2425
- check-manifest
2526
- twine==1.13.0
2627
- pillow<7.0.0

pytorch_lightning/callbacks/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .callback import Callback
1+
from .base import Callback
22
from .early_stopping import EarlyStopping
33
from .model_checkpoint import ModelCheckpoint
44
from .gradient_accumulation_scheduler import GradientAccumulationScheduler

pytorch_lightning/callbacks/callback.py pytorch_lightning/callbacks/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
Callbacks supported by Lightning
66
"""
77

8+
import abc
9+
10+
811
_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
912

1013

11-
class Callback(object):
14+
class Callback(abc.ABC):
1215
"""Abstract base class used to build new callbacks."""
1316

1417
def __init__(self):

pytorch_lightning/callbacks/early_stopping.py

+10-19
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from .callback import Callback
6+
from .base import Callback
77

88

99
class EarlyStopping(Callback):
@@ -38,9 +38,8 @@ class EarlyStopping(Callback):
3838
Trainer(early_stop_callback=early_stopping)
3939
"""
4040

41-
def __init__(self, monitor='val_loss',
42-
min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
43-
super(EarlyStopping, self).__init__()
41+
def __init__(self, monitor='val_loss', min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
42+
super().__init__()
4443

4544
self.monitor = monitor
4645
self.patience = patience
@@ -55,20 +54,13 @@ def __init__(self, monitor='val_loss',
5554
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
5655
mode = 'auto'
5756

58-
if mode == 'min':
59-
self.monitor_op = np.less
60-
elif mode == 'max':
61-
self.monitor_op = np.greater
62-
else:
63-
if 'acc' in self.monitor:
64-
self.monitor_op = np.greater
65-
else:
66-
self.monitor_op = np.less
67-
68-
if self.monitor_op == np.greater:
69-
self.min_delta *= 1
70-
else:
71-
self.min_delta *= -1
57+
mode_dict = {
58+
'min': np.less,
59+
'max': np.greater,
60+
'auto': np.greater if 'acc' in self.monitor else np.less
61+
}
62+
self.monitor_op = mode_dict[mode]
63+
self.min_delta *= 1 if self.monitor_op == np.greater else -1
7264

7365
self.on_train_begin()
7466

@@ -95,7 +87,6 @@ def on_train_begin(self):
9587
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
9688

9789
def on_epoch_end(self):
98-
9990
logs = self.trainer.callback_metrics
10091
stop_training = False
10192
if not self.check_metrics(logs):

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22

3-
from .callback import Callback
3+
from .base import Callback
44

55

66
class GradientAccumulationScheduler(Callback):
@@ -25,10 +25,10 @@ class GradientAccumulationScheduler(Callback):
2525
def __init__(self, scheduling: dict):
2626
super().__init__()
2727

28-
if scheduling == {}: # empty dict error
28+
if not scheduling: # empty dict error
2929
raise TypeError("Empty dict cannot be interpreted correctly")
3030

31-
for key in scheduling.keys():
31+
for key in scheduling:
3232
if not isinstance(key, int) or not isinstance(scheduling[key], int):
3333
raise TypeError("All epochs and accumulation factors must be integers")
3434

@@ -45,7 +45,6 @@ def __init__(self, scheduling: dict):
4545
self.epochs = sorted(scheduling.keys())
4646

4747
def on_epoch_begin(self):
48-
4948
trainer = self.trainer
5049
# indexing epochs from 1 (until v0.6.x)
5150
# In v0.8.0, ` + 1` should be removed.

pytorch_lightning/callbacks/model_checkpoint.py

+16-30
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55

66
import numpy as np
77

8-
from .callback import Callback
8+
from .base import Callback
99

1010

1111
class ModelCheckpoint(Callback):
1212
r"""
13-
1413
Save the model after every epoch.
1514
1615
Args:
@@ -27,14 +26,14 @@ class ModelCheckpoint(Callback):
2726
save_top_k (int): if `save_top_k == k`,
2827
the best k models according to
2928
the quantity monitored will be saved.
30-
if `save_top_k == 0`, no models are saved.
31-
if `save_top_k == -1`, all models are saved.
29+
if ``save_top_k == 0``, no models are saved.
30+
if ``save_top_k == -1``, all models are saved.
3231
Please note that the monitors are checked every `period` epochs.
33-
if `save_top_k >= 2` and the callback is called multiple
32+
if ``save_top_k >= 2`` and the callback is called multiple
3433
times inside an epoch, the name of the saved file will be
3534
appended with a version count starting with `v0`.
3635
mode (str): one of {auto, min, max}.
37-
If `save_top_k != 0`, the decision
36+
If ``save_top_k != 0``, the decision
3837
to overwrite the current save file is made
3938
based on either the maximization or the
4039
minimization of the monitored quantity. For `val_acc`,
@@ -60,11 +59,11 @@ class ModelCheckpoint(Callback):
6059
def __init__(self, filepath, monitor='val_loss', verbose=0,
6160
save_top_k=1, save_weights_only=False,
6261
mode='auto', period=1, prefix=''):
63-
super(ModelCheckpoint, self).__init__()
62+
super().__init__()
6463
if (
65-
save_top_k and
66-
os.path.isdir(filepath) and
67-
len(os.listdir(filepath)) > 0
64+
save_top_k
65+
and os.path.isdir(filepath)
66+
and len(os.listdir(filepath)) > 0
6867
):
6968
warnings.warn(
7069
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
@@ -111,34 +110,26 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
111110
self.mode = 'min'
112111

113112
def _del_model(self, filepath):
114-
dirpath = os.path.dirname(filepath)
115-
116-
# make paths
117-
os.makedirs(dirpath, exist_ok=True)
118-
119113
try:
120114
shutil.rmtree(filepath)
121115
except OSError:
122116
os.remove(filepath)
123117

124118
def _save_model(self, filepath):
125-
dirpath = os.path.dirname(filepath)
126-
127119
# make paths
128-
os.makedirs(dirpath, exist_ok=True)
120+
os.makedirs(os.path.dirname(filepath), exist_ok=True)
129121

130122
# delegate the saving to the model
131123
assert self.save_function is not None, ".save_function() not set"
132124
self.save_function(filepath)
133125

134126
def check_monitor_top_k(self, current):
135-
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
127+
less_than_k_models = len(self.best_k_models) < self.save_top_k
136128
if less_than_k_models:
137129
return True
138130
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
139131

140132
def on_validation_end(self):
141-
142133
logs = self.trainer.callback_metrics
143134
epoch = self.trainer.current_epoch
144135
self.epochs_since_last_check += 1
@@ -174,18 +165,13 @@ def on_validation_end(self):
174165
self.best_k_models[filepath] = current
175166
if len(self.best_k_models.keys()) == self.save_top_k:
176167
# monitor dict has reached k elements
177-
if self.mode == 'min':
178-
self.kth_best_model = max(
179-
self.best_k_models, key=self.best_k_models.get)
180-
else:
181-
self.kth_best_model = min(
182-
self.best_k_models, key=self.best_k_models.get)
168+
_op = min if self.mode == 'min' else max
169+
self.kth_best_model = _op(self.best_k_models, key=self.best_k_models.get)
183170
self.kth_value = self.best_k_models[self.kth_best_model]
184171

185-
if self.mode == 'min':
186-
self.best = min(self.best_k_models.values())
187-
else:
188-
self.best = max(self.best_k_models.values())
172+
_op = min if self.mode == 'min' else max
173+
self.best = _op(self.best_k_models.values())
174+
189175
if self.verbose > 0:
190176
log.info(
191177
f'\nEpoch {epoch:05d}: {self.monitor} reached'

0 commit comments

Comments
 (0)