@@ -78,7 +78,7 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
78
78
# {filename: monitor}
79
79
self .kth_best_model = ''
80
80
self .best = 0
81
- self .save_function = None
81
+ self .save_function = lambda x : None
82
82
83
83
if mode not in ['auto' , 'min' , 'max' ]:
84
84
warnings .warn (
@@ -115,8 +115,10 @@ def _save_model(self, filepath):
115
115
os .makedirs (os .path .dirname (filepath ), exist_ok = True )
116
116
117
117
# delegate the saving to the model
118
- assert self .save_function is not None , ".save_function() not set"
119
- self .save_function (filepath )
118
+ if self .save_function is not None :
119
+ self .save_function (filepath )
120
+ else :
121
+ raise ValueError (".save_function() not set" )
120
122
121
123
def check_monitor_top_k (self , current ):
122
124
less_than_k_models = len (self .best_k_models ) < self .save_top_k
@@ -150,31 +152,7 @@ def on_validation_end(self):
150
152
' skipping.' , RuntimeWarning )
151
153
else :
152
154
if self .check_monitor_top_k (current ):
153
-
154
- # remove kth
155
- if len (self .best_k_models ) == self .save_top_k :
156
- delpath = self .kth_best_model
157
- self .best_k_models .pop (self .kth_best_model )
158
- self ._del_model (delpath )
159
-
160
- self .best_k_models [filepath ] = current
161
- if len (self .best_k_models ) == self .save_top_k :
162
- # monitor dict has reached k elements
163
- _op = max if self .mode == 'min' else min
164
- self .kth_best_model = _op (self .best_k_models ,
165
- key = self .best_k_models .get )
166
- self .kth_value = self .best_k_models [self .kth_best_model ]
167
-
168
- _op = min if self .mode == 'min' else max
169
- self .best = _op (self .best_k_models .values ())
170
-
171
- if self .verbose > 0 :
172
- log .info (
173
- f'\n Epoch { epoch :05d} : { self .monitor } reached'
174
- f' { current :0.5f} (best { self .best :0.5f} ), saving model to'
175
- f' { filepath } as top { self .save_top_k } ' )
176
- self ._save_model (filepath )
177
-
155
+ self ._do_check_save (filepath , current , epoch )
178
156
else :
179
157
if self .verbose > 0 :
180
158
log .info (
@@ -185,3 +163,28 @@ def on_validation_end(self):
185
163
if self .verbose > 0 :
186
164
log .info (f'\n Epoch { epoch :05d} : saving model to { filepath } ' )
187
165
self ._save_model (filepath )
166
+
167
+ def _do_check_save (self , filepath , current , epoch ):
168
+ # remove kth
169
+ if len (self .best_k_models ) == self .save_top_k :
170
+ delpath = self .kth_best_model
171
+ self .best_k_models .pop (self .kth_best_model )
172
+ self ._del_model (delpath )
173
+
174
+ self .best_k_models [filepath ] = current
175
+ if len (self .best_k_models ) == self .save_top_k :
176
+ # monitor dict has reached k elements
177
+ _op = max if self .mode == 'min' else min
178
+ self .kth_best_model = _op (self .best_k_models ,
179
+ key = self .best_k_models .get )
180
+ self .kth_value = self .best_k_models [self .kth_best_model ]
181
+
182
+ _op = min if self .mode == 'min' else max
183
+ self .best = _op (self .best_k_models .values ())
184
+
185
+ if self .verbose > 0 :
186
+ log .info (
187
+ f'\n Epoch { epoch :05d} : { self .monitor } reached'
188
+ f' { current :0.5f} (best { self .best :0.5f} ), saving model to'
189
+ f' { filepath } as top { self .save_top_k } ' )
190
+ self ._save_model (filepath )
0 commit comments