16
16
from pytorch_lightning import _logger as log
17
17
from pytorch_lightning .callbacks .base import Callback
18
18
from pytorch_lightning .utilities import rank_zero_warn , rank_zero_only
19
+ from pytorch_lightning .utilities .io import gfile
19
20
20
21
21
22
class ModelCheckpoint (Callback ):
@@ -97,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
97
98
save_last : bool = False , save_top_k : int = 1 , save_weights_only : bool = False ,
98
99
mode : str = 'auto' , period : int = 1 , prefix : str = '' ):
99
100
super ().__init__ ()
100
- if save_top_k > 0 and filepath is not None and os .path .isdir (filepath ) and len (os .listdir (filepath )) > 0 :
101
+ if (filepath ):
102
+ filepath = str (filepath ) # the tests pass in a py.path.local but we want a str
103
+ if save_top_k > 0 and filepath is not None and gfile .isdir (filepath ) and len (gfile .listdir (filepath )) > 0 :
101
104
rank_zero_warn (
102
105
f"Checkpoint directory { filepath } exists and is not empty with save_top_k != 0."
103
106
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -109,12 +112,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
109
112
if filepath is None : # will be determined by trainer at runtime
110
113
self .dirpath , self .filename = None , None
111
114
else :
112
- if os . path .isdir (filepath ):
115
+ if gfile .isdir (filepath ):
113
116
self .dirpath , self .filename = filepath , '{epoch}'
114
117
else :
115
118
filepath = os .path .realpath (filepath )
116
119
self .dirpath , self .filename = os .path .split (filepath )
117
- os .makedirs (self .dirpath , exist_ok = True )
120
+ if not gfile .exists (self .dirpath ):
121
+ gfile .makedirs (self .dirpath )
118
122
self .save_last = save_last
119
123
self .save_top_k = save_top_k
120
124
self .save_weights_only = save_weights_only
@@ -156,12 +160,19 @@ def kth_best_model(self):
156
160
return self .kth_best_model_path
157
161
158
162
def _del_model (self , filepath ):
159
- if os .path .isfile (filepath ):
160
- os .remove (filepath )
163
+ if gfile .exists (filepath ):
164
+ try :
165
+ # in compat mode, remove is not implemented so if running this
166
+ # against an actual remote file system and the correct remote
167
+ # dependencies exist then this will work fine.
168
+ gfile .remove (filepath )
169
+ except AttributeError :
170
+ os .remove (filepath )
161
171
162
172
def _save_model (self , filepath ):
163
173
# make paths
164
- os .makedirs (os .path .dirname (filepath ), exist_ok = True )
174
+ if not gfile .exists (os .path .dirname (filepath )):
175
+ gfile .makedirs (os .path .dirname (filepath ))
165
176
166
177
# delegate the saving to the model
167
178
if self .save_function is not None :
@@ -249,7 +260,7 @@ def on_validation_end(self, trainer, pl_module):
249
260
250
261
filepath = self .format_checkpoint_name (epoch , metrics )
251
262
version_cnt = 0
252
- while os . path . isfile (filepath ):
263
+ while gfile . exists (filepath ):
253
264
filepath = self .format_checkpoint_name (epoch , metrics , ver = version_cnt )
254
265
# this epoch called before
255
266
version_cnt += 1
0 commit comments