5
5
6
6
import numpy as np
7
7
8
- from .callback import Callback
8
+ from .base import Callback
9
9
10
10
11
11
class ModelCheckpoint (Callback ):
12
12
r"""
13
-
14
13
Save the model after every epoch.
15
14
16
15
Args:
@@ -27,14 +26,14 @@ class ModelCheckpoint(Callback):
27
26
save_top_k (int): if `save_top_k == k`,
28
27
the best k models according to
29
28
the quantity monitored will be saved.
30
- if `save_top_k == 0`, no models are saved.
31
- if `save_top_k == -1`, all models are saved.
29
+ if `` save_top_k == 0` `, no models are saved.
30
+ if `` save_top_k == -1` `, all models are saved.
32
31
Please note that the monitors are checked every `period` epochs.
33
- if `save_top_k >= 2` and the callback is called multiple
32
+ if `` save_top_k >= 2` ` and the callback is called multiple
34
33
times inside an epoch, the name of the saved file will be
35
34
appended with a version count starting with `v0`.
36
35
mode (str): one of {auto, min, max}.
37
- If `save_top_k != 0`, the decision
36
+ If `` save_top_k != 0` `, the decision
38
37
to overwrite the current save file is made
39
38
based on either the maximization or the
40
39
minimization of the monitored quantity. For `val_acc`,
@@ -60,11 +59,11 @@ class ModelCheckpoint(Callback):
60
59
def __init__ (self , filepath , monitor = 'val_loss' , verbose = 0 ,
61
60
save_top_k = 1 , save_weights_only = False ,
62
61
mode = 'auto' , period = 1 , prefix = '' ):
63
- super (ModelCheckpoint , self ).__init__ ()
62
+ super ().__init__ ()
64
63
if (
65
- save_top_k and
66
- os .path .isdir (filepath ) and
67
- len (os .listdir (filepath )) > 0
64
+ save_top_k
65
+ and os .path .isdir (filepath )
66
+ and len (os .listdir (filepath )) > 0
68
67
):
69
68
warnings .warn (
70
69
f"Checkpoint directory { filepath } exists and is not empty with save_top_k != 0."
@@ -111,34 +110,26 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
111
110
self .mode = 'min'
112
111
113
112
def _del_model (self , filepath ):
114
- dirpath = os .path .dirname (filepath )
115
-
116
- # make paths
117
- os .makedirs (dirpath , exist_ok = True )
118
-
119
113
try :
120
114
shutil .rmtree (filepath )
121
115
except OSError :
122
116
os .remove (filepath )
123
117
124
118
def _save_model (self , filepath ):
125
- dirpath = os .path .dirname (filepath )
126
-
127
119
# make paths
128
- os .makedirs (dirpath , exist_ok = True )
120
+ os .makedirs (os . path . dirname ( filepath ) , exist_ok = True )
129
121
130
122
# delegate the saving to the model
131
123
assert self .save_function is not None , ".save_function() not set"
132
124
self .save_function (filepath )
133
125
134
126
def check_monitor_top_k (self , current ):
135
- less_than_k_models = len (self .best_k_models . keys () ) < self .save_top_k
127
+ less_than_k_models = len (self .best_k_models ) < self .save_top_k
136
128
if less_than_k_models :
137
129
return True
138
130
return self .monitor_op (current , self .best_k_models [self .kth_best_model ])
139
131
140
132
def on_validation_end (self ):
141
-
142
133
logs = self .trainer .callback_metrics
143
134
epoch = self .trainer .current_epoch
144
135
self .epochs_since_last_check += 1
@@ -174,18 +165,13 @@ def on_validation_end(self):
174
165
self .best_k_models [filepath ] = current
175
166
if len (self .best_k_models .keys ()) == self .save_top_k :
176
167
# monitor dict has reached k elements
177
- if self .mode == 'min' :
178
- self .kth_best_model = max (
179
- self .best_k_models , key = self .best_k_models .get )
180
- else :
181
- self .kth_best_model = min (
182
- self .best_k_models , key = self .best_k_models .get )
168
+ _op = min if self .mode == 'min' else max
169
+ self .kth_best_model = _op (self .best_k_models , key = self .best_k_models .get )
183
170
self .kth_value = self .best_k_models [self .kth_best_model ]
184
171
185
- if self .mode == 'min' :
186
- self .best = min (self .best_k_models .values ())
187
- else :
188
- self .best = max (self .best_k_models .values ())
172
+ _op = min if self .mode == 'min' else max
173
+ self .best = _op (self .best_k_models .values ())
174
+
189
175
if self .verbose > 0 :
190
176
log .info (
191
177
f'\n Epoch { epoch :05d} : { self .monitor } reached'
0 commit comments