Skip to content

Commit 4d9d028

Browse files
committed
support newline setting in cloud_io
1 parent c77816c commit 4d9d028

File tree

2 files changed

+13
-19
lines changed

2 files changed

+13
-19
lines changed

pytorch_lightning/core/saving.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
277277
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
278278
return {}
279279

280-
with cloud_open(tags_csv, "rb") as fp:
281-
csv_reader = csv.reader(fp.read().decode("unicode_escape"), delimiter=",")
280+
with cloud_open(tags_csv, "r") as fp:
281+
csv_reader = csv.reader(fp.read(), delimiter=",")
282282
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
283283

284284
return tags
@@ -291,15 +291,12 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
291291
if isinstance(hparams, Namespace):
292292
hparams = vars(hparams)
293293

294-
# write to a buffer first since cloud_open doesn't support the newline setting
295-
strbuffer = io.StringIO(newline="")
296-
fieldnames = ["key", "value"]
297-
writer = csv.DictWriter(strbuffer, fieldnames=fieldnames)
298-
writer.writerow({"key": "key", "value": "value"})
299-
for k, v in hparams.items():
300-
writer.writerow({"key": k, "value": v})
301-
with cloud_open(tags_csv, "wb") as fp:
302-
fp.write(strbuffer.getvalue().encode("unicode_escape"))
294+
with cloud_open(tags_csv, "w", newline="") as fp:
295+
fieldnames = ["key", "value"]
296+
writer = csv.DictWriter(fp, fieldnames=fieldnames)
297+
writer.writerow({"key": "key", "value": "value"})
298+
for k, v in hparams.items():
299+
writer.writerow({"key": k, "value": v})
303300

304301

305302
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -345,11 +342,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
345342
hparams = dict(hparams)
346343
assert isinstance(hparams, dict)
347344

348-
# cloud_open doesnt support newline settings so write to a buffer first
349-
strbuffer = io.StringIO(newline="")
350-
yaml.dump(hparams, strbuffer)
351-
with cloud_open(config_yaml, "w") as fp:
352-
fp.write(strbuffer.getvalue())
345+
with cloud_open(config_yaml, "w", newline="") as fp:
346+
yaml.dump(hparams, fp)
353347

354348

355349
def convert(val: str) -> Union[int, float, bool, str]:

pytorch_lightning/utilities/cloud_io.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ def load(path_or_url: str, map_location=None):
2929
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
3030

3131

32-
def cloud_open(path: pathlike, mode: str):
32+
def cloud_open(path: pathlike, mode: str, newline:str = None):
3333
if not modern_gfile or sys.platform == "win32":
3434
log.debug(
3535
"tenosrboard.compat gfile does not work on older versions "
3636
"of tensorboard normal local file open."
3737
)
38-
return open(path, mode)
38+
return open(path, mode, newline=newline)
3939
if sys.platform == "win32":
4040
log.debug(
4141
"gfile does not handle newlines correctly on windows so remote files are not"
4242
"supported falling back to normal local file open."
4343
)
44-
return open(path, mode)
44+
return open(path, mode, newline=newline)
4545
try:
4646
return gfile.GFile(path, mode)
4747
except NotImplementedError as e:

0 commit comments

Comments
 (0)