Skip to content

Commit a8ada6a

Browse files
committed
use gfile to support remote directories
Tests all use the `tmpfile` fixture, which provides a py.path.local that is incompatible with compat.gfile. The contract in many places is type str or Optional[str], which py.path.local is not. I hope that folks are not passing in py.path.local objects; if so, this change will break them. The type annotations say to use str, so this should be OK. The other option is to just explicitly convert to str so as not to break people using an incorrect type (like the tests were doing).
1 parent c826a5f commit a8ada6a

File tree

7 files changed

+65
-32
lines changed

7 files changed

+65
-32
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch_lightning import _logger as log
1717
from pytorch_lightning.callbacks.base import Callback
1818
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
19+
from pytorch_lightning.utilities.io import gfile
1920

2021

2122
class ModelCheckpoint(Callback):
@@ -97,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
9798
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
9899
mode: str = 'auto', period: int = 1, prefix: str = ''):
99100
super().__init__()
100-
if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
101+
if(filepath):
102+
filepath = str(filepath) # the tests pass in a py.path.local but we want a str
103+
if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
101104
rank_zero_warn(
102105
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
103106
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -109,12 +112,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
109112
if filepath is None: # will be determined by trainer at runtime
110113
self.dirpath, self.filename = None, None
111114
else:
112-
if os.path.isdir(filepath):
115+
if gfile.isdir(filepath):
113116
self.dirpath, self.filename = filepath, '{epoch}'
114117
else:
115118
filepath = os.path.realpath(filepath)
116119
self.dirpath, self.filename = os.path.split(filepath)
117-
os.makedirs(self.dirpath, exist_ok=True)
120+
if not gfile.exists(self.dirpath):
121+
gfile.makedirs(self.dirpath)
118122
self.save_last = save_last
119123
self.save_top_k = save_top_k
120124
self.save_weights_only = save_weights_only
@@ -156,12 +160,19 @@ def kth_best_model(self):
156160
return self.kth_best_model_path
157161

158162
def _del_model(self, filepath):
159-
if os.path.isfile(filepath):
160-
os.remove(filepath)
163+
if gfile.exists(filepath):
164+
try:
165+
# in compat mode, remove is not implemented so if running this
166+
# against an actual remote file system and the correct remote
167+
# dependencies exist then this will work fine.
168+
gfile.remove(filepath)
169+
except AttributeError:
170+
os.remove(filepath)
161171

162172
def _save_model(self, filepath):
163173
# make paths
164-
os.makedirs(os.path.dirname(filepath), exist_ok=True)
174+
if not gfile.exists(os.path.dirname(filepath)):
175+
gfile.makedirs(os.path.dirname(filepath))
165176

166177
# delegate the saving to the model
167178
if self.save_function is not None:
@@ -249,7 +260,7 @@ def on_validation_end(self, trainer, pl_module):
249260

250261
filepath = self.format_checkpoint_name(epoch, metrics)
251262
version_cnt = 0
252-
while os.path.isfile(filepath):
263+
while gfile.exists(filepath):
253264
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
254265
# this epoch called before
255266
version_cnt += 1

pytorch_lightning/core/saving.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
1313
from pytorch_lightning.utilities.io import load as pl_load
1414

15+
# we want this for tf.io.gfile, which if tf is installed gives full tf,
16+
# otherwise gives a pruned down version which works for some file backends but
17+
# not all
18+
from tensorboard.compat import tf
19+
1520
PRIMITIVE_TYPES = (bool, int, float, str)
1621
ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
1722
try:
@@ -269,25 +274,25 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
269274
True
270275
>>> os.remove(path_csv)
271276
"""
272-
if not os.path.isfile(tags_csv):
277+
if not tf.io.gfile.exists(tags_csv):
273278
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
274279
return {}
275280

276-
with open(tags_csv) as fp:
281+
with tf.io.gfile.GFile(tags_csv, "r") as fp:
277282
csv_reader = csv.reader(fp, delimiter=',')
278283
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
279284

280285
return tags
281286

282287

283288
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
284-
if not os.path.isdir(os.path.dirname(tags_csv)):
289+
if not tf.io.gfile.isdir(os.path.dirname(tags_csv)):
285290
raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
286291

287292
if isinstance(hparams, Namespace):
288293
hparams = vars(hparams)
289294

290-
with open(tags_csv, 'w') as fp:
295+
with tf.io.gfile.GFile(tags_csv, 'w') as fp:
291296
fieldnames = ['key', 'value']
292297
writer = csv.DictWriter(fp, fieldnames=fieldnames)
293298
writer.writerow({'key': 'key', 'value': 'value'})
@@ -306,24 +311,24 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
306311
True
307312
>>> os.remove(path_yaml)
308313
"""
309-
if not os.path.isfile(config_yaml):
314+
if not tf.io.gfile.exists(config_yaml):
310315
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
311316
return {}
312317

313-
with open(config_yaml) as fp:
318+
with tf.io.gfile.GFile(config_yaml, "r") as fp:
314319
tags = yaml.load(fp, Loader=yaml.SafeLoader)
315320

316321
return tags
317322

318323

319324
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
320-
if not os.path.isdir(os.path.dirname(config_yaml)):
325+
if not tf.io.gfile.isdir(os.path.dirname(config_yaml)):
321326
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
322327

323328
if isinstance(hparams, Namespace):
324329
hparams = vars(hparams)
325330

326-
with open(config_yaml, 'w', newline='') as fp:
331+
with tf.io.gfile.GFile(config_yaml, 'w') as fp:
327332
yaml.dump(hparams, fp)
328333

329334

pytorch_lightning/loggers/tensorboard.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytorch_lightning.core.saving import save_hparams_to_yaml
1818
from pytorch_lightning.loggers.base import LightningLoggerBase
1919
from pytorch_lightning.utilities import rank_zero_only
20+
from pytorch_lightning.utilities.io import gfile
2021

2122

2223
class TensorBoardLogger(LightningLoggerBase):
@@ -97,7 +98,8 @@ def experiment(self) -> SummaryWriter:
9798
if self._experiment is not None:
9899
return self._experiment
99100

100-
os.makedirs(self.root_dir, exist_ok=True)
101+
if not gfile.exists(self.root_dir):
102+
gfile.makedirs(self.root_dir)
101103
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
102104
return self._experiment
103105

@@ -145,7 +147,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
145147
def save(self) -> None:
146148
super().save()
147149
dir_path = self.log_dir
148-
if not os.path.isdir(dir_path):
150+
if not gfile.isdir(dir_path):
149151
dir_path = self.save_dir
150152

151153
# prepare the file path
@@ -171,13 +173,13 @@ def version(self) -> int:
171173
def _get_next_version(self):
172174
root_dir = os.path.join(self.save_dir, self.name)
173175

174-
if not os.path.isdir(root_dir):
176+
if not gfile.isdir(root_dir):
175177
log.warning('Missing logger folder: %s', root_dir)
176178
return 0
177179

178180
existing_versions = []
179-
for d in os.listdir(root_dir):
180-
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
181+
for d in gfile.listdir(root_dir):
182+
if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
181183
existing_versions.append(int(d.split("_")[1]))
182184

183185
if len(existing_versions) == 0:

pytorch_lightning/trainer/callback_config.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
77
from pytorch_lightning.loggers import LightningLoggerBase
88
from pytorch_lightning.utilities.exceptions import MisconfigurationException
9+
from pytorch_lightning.utilities.io import gfile
910

1011

1112
class TrainerCallbackConfigMixin(ABC):
@@ -67,7 +68,8 @@ def configure_checkpoint_callback(self):
6768
monitor_key = 'loss' if train_step_only else 'val_loss'
6869

6970
if self.checkpoint_callback is True:
70-
os.makedirs(ckpt_path, exist_ok=True)
71+
if not gfile.exists(ckpt_path):
72+
gfile.makedirs(ckpt_path)
7173
self.checkpoint_callback = ModelCheckpoint(
7274
filepath=ckpt_path,
7375
monitor=monitor_key
@@ -77,7 +79,9 @@ def configure_checkpoint_callback(self):
7779
and self.checkpoint_callback.dirpath is None:
7880
self.checkpoint_callback.dirpath = ckpt_path
7981
self.checkpoint_callback.filename = '{epoch}'
80-
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
82+
if not gfile.exists(self.checkpoint_callback.dirpath):
83+
gfile.makedirs(self.checkpoint_callback.dirpath)
84+
8185
elif self.checkpoint_callback is False:
8286
self.checkpoint_callback = None
8387

pytorch_lightning/trainer/trainer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,11 @@ def __init__(
371371
' val and test loop using a single batch')
372372

373373
# set default save path if user didn't provide one
374-
self.default_root_dir = default_root_dir
375-
376-
if self.default_root_dir is None:
374+
if default_root_dir is None:
377375
self.default_root_dir = os.getcwd()
376+
else:
377+
# we have to do str() because the unit tests violate the type annotation and pass path objects
378+
self.default_root_dir = str(default_root_dir)
378379

379380
# training bookeeping
380381
self.total_batch_idx = 0

pytorch_lightning/trainer/training_io.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
)
103103
from pytorch_lightning.utilities import rank_zero_warn, parsing
104104
from pytorch_lightning.utilities.io import load as pl_load
105+
from pytorch_lightning.utilities.io import gfile
105106

106107
try:
107108
import torch_xla
@@ -374,9 +375,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
374375
did_restore = False
375376

376377
# look for hpc weights
377-
folderpath = self.weights_save_path
378-
if os.path.exists(folderpath):
379-
files = os.listdir(folderpath)
378+
folderpath = str(self.weights_save_path)
379+
if gfile.exists(folderpath):
380+
files = gfile.listdir(folderpath)
380381
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
381382

382383
# if hpc weights exist restore model
@@ -451,15 +452,17 @@ def restore_training_state(self, checkpoint):
451452
# ----------------------------------
452453
def hpc_save(self, folderpath: str, logger):
453454
# make sure the checkpoint folder exists
454-
os.makedirs(folderpath, exist_ok=True)
455+
folderpath = str(folderpath) # because the tests pass a path object
456+
if not gfile.exists(folderpath):
457+
gfile.makedirs(folderpath)
455458

456459
# save logger to make sure we get all the metrics
457460
logger.save()
458461

459462
ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
460463

461-
if not os.path.exists(folderpath):
462-
os.makedirs(folderpath, exist_ok=True)
464+
if not gfile.exists(folderpath):
465+
gfile.makedirs(folderpath)
463466
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
464467

465468
# give model a chance to do something on hpc_save
@@ -509,7 +512,7 @@ def hpc_load(self, folderpath, on_gpu):
509512
log.info(f'restored hpc model from: {filepath}')
510513

511514
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
512-
files = os.listdir(path)
515+
files = gfile.listdir(str(path))
513516
files = [x for x in files if name_key in x]
514517
if len(files) == 0:
515518
return 0

pytorch_lightning/utilities/io.py

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
from urllib.parse import urlparse
44

5+
# we want this for tf.io.gfile, which if tf is installed gives full tf,
6+
# otherwise gives a pruned down version which works for some file backends but
7+
# not all
8+
from tensorboard.compat import tf
9+
10+
gfile = tf.io.gfile
11+
512

613
def load(path_or_url: str, map_location=None):
714
parsed = urlparse(path_or_url)

0 commit comments

Comments
 (0)