Skip to content

Commit 9bbf29c

Browse files
committed
fix logic error
1 parent a435ffe commit 9bbf29c

File tree

2 files changed

+37
-33
lines changed

2 files changed

+37
-33
lines changed

pytorch_lightning/callbacks/early_stopping.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,17 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
5050
self.wait = 0
5151
self.stopped_epoch = 0
5252

53-
if mode not in mode_dict:
54-
if self.verbose > 0:
55-
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
56-
mode = 'auto'
57-
5853
mode_dict = {
5954
'min': np.less,
6055
'max': np.greater,
6156
'auto': np.greater if 'acc' in self.monitor else np.less
6257
}
58+
59+
if mode not in mode_dict:
60+
if self.verbose > 0:
61+
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
62+
mode = 'auto'
63+
6364
self.monitor_op = mode_dict[mode]
6465
self.min_delta *= 1 if self.monitor_op == np.greater else -1
6566

pytorch_lightning/callbacks/model_checkpoint.py

+31-28
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
7878
# {filename: monitor}
7979
self.kth_best_model = ''
8080
self.best = 0
81-
self.save_function = None
81+
self.save_function = lambda x: None
8282

8383
if mode not in ['auto', 'min', 'max']:
8484
warnings.warn(
@@ -115,8 +115,10 @@ def _save_model(self, filepath):
115115
os.makedirs(os.path.dirname(filepath), exist_ok=True)
116116

117117
# delegate the saving to the model
118-
assert self.save_function is not None, ".save_function() not set"
119-
self.save_function(filepath)
118+
if self.save_function is not None:
119+
self.save_function(filepath)
120+
else:
121+
raise ValueError(".save_function() not set")
120122

121123
def check_monitor_top_k(self, current):
122124
less_than_k_models = len(self.best_k_models) < self.save_top_k
@@ -150,31 +152,7 @@ def on_validation_end(self):
150152
' skipping.', RuntimeWarning)
151153
else:
152154
if self.check_monitor_top_k(current):
153-
154-
# remove kth
155-
if len(self.best_k_models) == self.save_top_k:
156-
delpath = self.kth_best_model
157-
self.best_k_models.pop(self.kth_best_model)
158-
self._del_model(delpath)
159-
160-
self.best_k_models[filepath] = current
161-
if len(self.best_k_models) == self.save_top_k:
162-
# monitor dict has reached k elements
163-
_op = max if self.mode == 'min' else min
164-
self.kth_best_model = _op(self.best_k_models,
165-
key=self.best_k_models.get)
166-
self.kth_value = self.best_k_models[self.kth_best_model]
167-
168-
_op = min if self.mode == 'min' else max
169-
self.best = _op(self.best_k_models.values())
170-
171-
if self.verbose > 0:
172-
log.info(
173-
f'\nEpoch {epoch:05d}: {self.monitor} reached'
174-
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
175-
f' {filepath} as top {self.save_top_k}')
176-
self._save_model(filepath)
177-
155+
self._do_check_save(filepath, current, epoch)
178156
else:
179157
if self.verbose > 0:
180158
log.info(
@@ -185,3 +163,28 @@ def on_validation_end(self):
185163
if self.verbose > 0:
186164
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
187165
self._save_model(filepath)
166+
167+
def _do_check_save(self, filepath, current, epoch):
168+
# remove kth
169+
if len(self.best_k_models) == self.save_top_k:
170+
delpath = self.kth_best_model
171+
self.best_k_models.pop(self.kth_best_model)
172+
self._del_model(delpath)
173+
174+
self.best_k_models[filepath] = current
175+
if len(self.best_k_models) == self.save_top_k:
176+
# monitor dict has reached k elements
177+
_op = max if self.mode == 'min' else min
178+
self.kth_best_model = _op(self.best_k_models,
179+
key=self.best_k_models.get)
180+
self.kth_value = self.best_k_models[self.kth_best_model]
181+
182+
_op = min if self.mode == 'min' else max
183+
self.best = _op(self.best_k_models.values())
184+
185+
if self.verbose > 0:
186+
log.info(
187+
f'\nEpoch {epoch:05d}: {self.monitor} reached'
188+
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
189+
f' {filepath} as top {self.save_top_k}')
190+
self._save_model(filepath)

0 commit comments

Comments
 (0)