Commit 6175d4e

Use tensorboard.compat.gfile to support remote writing

1 parent 2cc60c6
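
The change routes all filesystem access through TensorBoard's compat layer: with TensorFlow installed this resolves to the full tf.io.gfile, otherwise to TensorBoard's pruned-down stub, whose remote backends (such as S3) work without TensorFlow. A minimal sketch of that resolution, assuming tensorboard >= 2.0 is installed:

from tensorboard.compat import tf  # full TF if installed, else TensorBoard's stub

gfile = tf.io.gfile
print(gfile.exists("/tmp"))  # local paths work with either backend
# remote schemes such as s3:// additionally require boto3 (added to the deps below)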

File tree

9 files changed: +552 -321 lines changed


environment.yml (+2)

@@ -34,6 +34,8 @@ dependencies:
   - pillow<7.0.0
   - scikit-image
   - nltk>=3.3
+  - boto3
+  - moto>=1.3.14
 
   # Optional
   - scikit-learn>=0.20.0

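boto3 backs the compat layer's S3 filesystem, and moto lets tests exercise S3 paths against an in-memory fake. A hypothetical test sketch of the pattern these two additions enable (the bucket name and test are illustrative, not part of this commit):

import boto3
from moto import mock_s3

from pytorch_lightning.utilities.cloud_io import cloud_open


@mock_s3
def test_s3_roundtrip():
    # moto intercepts boto3, so gfile's S3 backend writes to an in-memory bucket
    boto3.resource("s3", region_name="us-east-1").create_bucket(Bucket="test-bucket")

    with cloud_open("s3://test-bucket/tags.csv", "w") as fp:
        fp.write("key,value\nepochs,5\n")
    with cloud_open("s3://test-bucket/tags.csv", "r") as fp:
        assert "epochs" in fp.read()
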
pytorch_lightning/callbacks/model_checkpoint.py (+18 -7)

@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs
 
 
 class ModelCheckpoint(Callback):
@@ -100,7 +101,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -112,12 +115,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
                 filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -159,16 +163,23 @@ def kth_best_model(self):
         return self.kth_best_model_path
 
     def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented, so when running
+                # against an actual remote file system with the correct remote
+                # dependencies installed this will work fine.
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)
 
     def _save_model(self, filepath, trainer, pl_module):
 
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            makedirs(os.path.dirname(filepath))
 
         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,7 +319,7 @@ def on_validation_end(self, trainer, pl_module):
 
         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
-        while os.path.isfile(filepath):
+        while gfile.exists(filepath):
             filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1

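With these changes a checkpoint directory can live on any gfile-supported backend. A usage sketch with a hypothetical bucket, assuming boto3 and tensorboard >= 2.0 so the remote code path is taken:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# pass an existing remote directory: gfile.isdir() must be True for it,
# since non-directory paths still go through os.path.realpath() above
checkpoint_callback = ModelCheckpoint(filepath="s3://my-bucket/checkpoints/")
trainer = Trainer(checkpoint_callback=checkpoint_callback)
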
pytorch_lightning/core/saving.py (+19 -16)

@@ -1,5 +1,6 @@
 import ast
 import csv
+import io
 import inspect
 import os
 
@@ -11,6 +12,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
 
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -273,30 +275,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not os.path.isfile(tags_csv):
-        rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
+    if not gfile.exists(tags_csv):
+        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
         return {}
 
-    with open(tags_csv) as fp:
-        csv_reader = csv.reader(fp, delimiter=',')
+    with cloud_open(tags_csv, "r") as fp:
+        csv_reader = csv.reader(fp.read().splitlines(), delimiter=",")
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
 
     return tags
 
 
 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
+    if not gfile.isdir(os.path.dirname(tags_csv)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")
 
     if isinstance(hparams, Namespace):
         hparams = vars(hparams)
 
-    with open(tags_csv, 'w', newline='') as fp:
-        fieldnames = ['key', 'value']
+    with cloud_open(tags_csv, "w", newline="") as fp:
+        fieldnames = ["key", "value"]
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
-        writer.writerow({'key': 'key', 'value': 'value'})
+        writer.writerow({"key": "key", "value": "value"})
         for k, v in hparams.items():
-            writer.writerow({'key': k, 'value': v})
+            writer.writerow({"key": k, "value": v})
 
 
 def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -310,11 +312,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not os.path.isfile(config_yaml):
-        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
+    if not gfile.exists(config_yaml):
+        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
         return {}
 
-    with open(config_yaml) as fp:
+    with cloud_open(config_yaml, "r") as fp:
         tags = yaml.load(fp)
 
     return tags
@@ -326,11 +328,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
     """
-    if not os.path.isdir(os.path.dirname(config_yaml)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
+    if not gfile.isdir(os.path.dirname(config_yaml)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
 
     if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
+
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return
 
@@ -341,7 +344,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
     hparams = dict(hparams)
     assert isinstance(hparams, dict)
 
-    with open(config_yaml, 'w', newline='') as fp:
+    with cloud_open(config_yaml, "w", newline="") as fp:
         yaml.dump(hparams, fp)

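The hparams helpers now accept the same remote paths through cloud_open. A round-trip sketch with a hypothetical bucket (the parent "directory" must already exist, since the save functions check gfile.isdir() first):

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

save_hparams_to_yaml("s3://my-bucket/run-1/hparams.yaml", {"lr": 1e-3, "batch_size": 32})
hparams = load_hparams_from_yaml("s3://my-bucket/run-1/hparams.yaml")
assert hparams["batch_size"] == 32
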
pytorch_lightning/loggers/tensorboard.py (+7 -5)

@@ -16,6 +16,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs
 
 try:
     from omegaconf import Container, OmegaConf
@@ -109,7 +110,8 @@ def experiment(self) -> SummaryWriter:
             return self._experiment
 
         assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        os.makedirs(self.root_dir, exist_ok=True)
+        if self.root_dir and not gfile.exists(str(self.root_dir)):
+            makedirs(self.root_dir)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
@@ -162,7 +164,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
             dir_path = self.save_dir
 
         # prepare the file path
@@ -188,13 +190,13 @@ def version(self) -> int:
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)
 
-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0
 
         existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                 existing_versions.append(int(d.split("_")[1]))
 
         if len(existing_versions) == 0:

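SummaryWriter already writes event files through gfile, so with the directory handling converted as well, a remote save_dir works end to end. A sketch with a hypothetical bucket:

from pytorch_lightning.loggers import TensorBoardLogger

# version folders ("version_0", "version_1", ...) are now discovered
# via gfile.listdir(), so versioning keeps working on remote storage
logger = TensorBoardLogger(save_dir="s3://my-bucket/tb_logs", name="experiment")
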
pytorch_lightning/trainer/training_io.py (+10 -7)

@@ -105,6 +105,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs
 
 try:
     import torch_xla
@@ -409,9 +410,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False
 
         # look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
 
             # if hpc weights exist restore model
@@ -490,15 +491,17 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)
 
         # save logger to make sure we get all the metrics
         logger.save()
 
         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
 
-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
 
         # give model a chance to do something on hpc_save
@@ -551,7 +554,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')
 
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
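
hpc_save/hpc_load and the HPC weights scan go through the same layer, so weights_save_path may also point at remote storage. A sketch (hypothetical bucket, same dependency assumptions as above):

from pytorch_lightning import Trainer

# hpc_ckpt_*.ckpt files are enumerated with gfile.listdir() on this path
trainer = Trainer(weights_save_path="s3://my-bucket/hpc-weights")
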
pytorch_lightning/utilities/cloud_io.py (+49 -4)

@@ -1,11 +1,56 @@
-import torch
-
+import sys
+import os
+from typing import Union
 from pathlib import Path
 from urllib.parse import urlparse
+import torch
+
+import tensorboard
+from packaging import version
+from pytorch_lightning import _logger as log
+
+# we want this for tf.io.gfile, which if tf is installed gives the full tf,
+# otherwise gives a pruned-down version which works for some file backends
+# but not all
+from tensorboard.compat import tf
+
+gfile = tf.io.gfile
+
+pathlike = Union[Path, str]
+
+# older versions of tensorboard had buggy gfile compatibility layers;
+# only support remote cloud paths on newer ones
+modern_gfile = version.parse(tensorboard.version.VERSION) >= version.parse('2.0')
 
 
 def load(path_or_url: str, map_location=None):
     if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
         return torch.load(path_or_url, map_location=map_location)
-    else:
-        return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+
+
+def cloud_open(path: pathlike, mode: str, newline: str = None):
+    if not modern_gfile:
+        log.debug(
+            "tensorboard.compat gfile does not work on older versions "
+            "of tensorboard; falling back to normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    if sys.platform == "win32":
+        log.debug(
+            "gfile does not handle newlines correctly on windows so remote "
+            "files are not supported; falling back to normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    try:
+        return gfile.GFile(path, mode)
+    except NotImplementedError:
+        # minimal dependencies are installed and only local files will work
+        return open(path, mode, newline=newline)
+
+
+def makedirs(path: pathlike):
+    if modern_gfile and hasattr(gfile, "makedirs"):
+        return gfile.makedirs(str(path))
+    # otherwise minimal dependencies are installed and only local files will work
+    return os.makedirs(path, exist_ok=True)

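A direct usage sketch of the two new helpers with a hypothetical bucket, assuming tensorboard >= 2.0 and boto3 so the remote branch is taken (on older tensorboard or on Windows, cloud_open falls back to the plain built-in open):

from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs

makedirs("s3://my-bucket/run-1")  # delegates to gfile.makedirs on modern tensorboard
with cloud_open("s3://my-bucket/run-1/notes.txt", "w") as fp:
    fp.write("remote writing via tensorboard.compat gfile\n")

print(gfile.listdir("s3://my-bucket/run-1"))  # ['notes.txt']
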
requirements/base.txt (+1)

@@ -7,3 +7,4 @@ future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement >=5.1
 tqdm>=4.41.0
+packaging

requirements/test.txt (+3)

@@ -12,4 +12,7 @@ black==19.10b0
 pre-commit>=1.0
 
 cloudpickle>=1.2
+
+boto3
+moto>=1.3.14
 nltk>=3.3
