1
1
import ast
2
2
import csv
3
+ import io
3
4
import inspect
4
5
import os
5
6
11
12
from pytorch_lightning import _logger as log
12
13
from pytorch_lightning .utilities import rank_zero_warn , AttributeDict
13
14
from pytorch_lightning .utilities .cloud_io import load as pl_load
15
+ from pytorch_lightning .utilities .cloud_io import gfile , cloud_open
14
16
15
17
PRIMITIVE_TYPES = (bool , int , float , str )
16
18
ALLOWED_CONFIG_TYPES = (AttributeDict , MutableMapping , Namespace )
@@ -273,30 +275,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
273
275
True
274
276
>>> os.remove(path_csv)
275
277
"""
276
- if not os . path . isfile (tags_csv ):
277
- rank_zero_warn (f' Missing Tags: { tags_csv } .' , RuntimeWarning )
278
+ if not gfile . exists (tags_csv ):
279
+ rank_zero_warn (f" Missing Tags: { tags_csv } ." , RuntimeWarning )
278
280
return {}
279
281
280
- with open (tags_csv ) as fp :
281
- csv_reader = csv .reader (fp , delimiter = ',' )
282
+ with cloud_open (tags_csv , "r" ) as fp :
283
+ csv_reader = csv .reader (fp . read () , delimiter = "," )
282
284
tags = {row [0 ]: convert (row [1 ]) for row in list (csv_reader )[1 :]}
283
285
284
286
return tags
285
287
286
288
287
289
def save_hparams_to_tags_csv (tags_csv : str , hparams : Union [dict , Namespace ]) -> None :
288
- if not os . path .isdir (os .path .dirname (tags_csv )):
289
- raise RuntimeError (f' Missing folder: { os .path .dirname (tags_csv )} .' )
290
+ if not gfile .isdir (os .path .dirname (tags_csv )):
291
+ raise RuntimeError (f" Missing folder: { os .path .dirname (tags_csv )} ." )
290
292
291
293
if isinstance (hparams , Namespace ):
292
294
hparams = vars (hparams )
293
295
294
- with open (tags_csv , 'w' , newline = '' ) as fp :
295
- fieldnames = [' key' , ' value' ]
296
+ with cloud_open (tags_csv , "w" , newline = "" ) as fp :
297
+ fieldnames = [" key" , " value" ]
296
298
writer = csv .DictWriter (fp , fieldnames = fieldnames )
297
- writer .writerow ({' key' : ' key' , ' value' : ' value' })
299
+ writer .writerow ({" key" : " key" , " value" : " value" })
298
300
for k , v in hparams .items ():
299
- writer .writerow ({' key' : k , ' value' : v })
301
+ writer .writerow ({" key" : k , " value" : v })
300
302
301
303
302
304
def load_hparams_from_yaml (config_yaml : str ) -> Dict [str , Any ]:
@@ -310,11 +312,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
310
312
True
311
313
>>> os.remove(path_yaml)
312
314
"""
313
- if not os . path . isfile (config_yaml ):
314
- rank_zero_warn (f' Missing Tags: { config_yaml } .' , RuntimeWarning )
315
+ if not gfile . exists (config_yaml ):
316
+ rank_zero_warn (f" Missing Tags: { config_yaml } ." , RuntimeWarning )
315
317
return {}
316
318
317
- with open (config_yaml ) as fp :
319
+ with cloud_open (config_yaml , "r" ) as fp :
318
320
tags = yaml .load (fp )
319
321
320
322
return tags
@@ -326,11 +328,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
326
328
config_yaml: path to new YAML file
327
329
hparams: parameters to be saved
328
330
"""
329
- if not os . path .isdir (os .path .dirname (config_yaml )):
330
- raise RuntimeError (f' Missing folder: { os .path .dirname (config_yaml )} .' )
331
+ if not gfile .isdir (os .path .dirname (config_yaml )):
332
+ raise RuntimeError (f" Missing folder: { os .path .dirname (config_yaml )} ." )
331
333
332
334
if OMEGACONF_AVAILABLE and isinstance (hparams , Container ):
333
335
from omegaconf import OmegaConf
336
+
334
337
OmegaConf .save (hparams , config_yaml , resolve = True )
335
338
return
336
339
@@ -341,7 +344,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
341
344
hparams = dict (hparams )
342
345
assert isinstance (hparams , dict )
343
346
344
- with open (config_yaml , 'w' , newline = '' ) as fp :
347
+ with cloud_open (config_yaml , "w" , newline = "" ) as fp :
345
348
yaml .dump (hparams , fp )
346
349
347
350
0 commit comments