Skip to content

Commit 8fa802e

Browse files
authored
Tensorboard path generalisation (#804)
* Allow experiment versions to be overridden by passing a string value. Allow experiment names to be empty, in which case no per-experiment subdirectory will be created and checkpoints will be saved in the directory given by the save_dir parameter.
* Document TensorBoard API changes.
* Review comment fixes, plus a fix for the test failure in the minimum-requirements build.
* More formatting fixes from review.
1 parent fc0ad03 commit 8fa802e

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

pytorch_lightning/loggers/tensorboard.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ class TensorBoardLogger(LightningLoggerBase):
2929
3030
Args:
3131
save_dir (str): Save directory
32-
name (str): Experiment name. Defaults to "default".
33-
version (int): Experiment version. If version is not specified the logger inspects the save
34-
directory for existing versions, then automatically assigns the next available version.
32+
name (str): Experiment name. Defaults to "default". If it is the empty string then no per-experiment
33+
subdirectory is used.
34+
version (int|str): Experiment version. If version is not specified the logger inspects the save
35+
directory for existing versions, then automatically assigns the next available version.
36+
If it is a string then it is used as the run-specific subdirectory name,
37+
otherwise version_${version} is used.
3538
\**kwargs (dict): Other arguments are passed directly to the :class:`SummaryWriter` constructor.
3639
3740
"""
@@ -47,6 +50,30 @@ def __init__(self, save_dir, name="default", version=None, **kwargs):
4750
self.tags = {}
4851
self.kwargs = kwargs
4952

53+
@property
def root_dir(self):
    """Return the parent directory for all tensorboard checkpoint subdirectories.

    If the experiment name is ``None`` or empty, no per-experiment
    subdirectory is used and ``save_dir`` itself is returned, so checkpoints
    end up in ``save_dir/<version dir>``. Otherwise the result is
    ``save_dir`` joined with the experiment name.
    """
    # Truthiness covers both None and "" and, unlike len(), does not raise
    # for path-like names that define no __len__.
    if not self.name:
        return self.save_dir
    return os.path.join(self.save_dir, self.name)
64+
65+
@property
def log_dir(self):
    """Return the directory for this run's tensorboard checkpoint.

    By default the directory is named ``version_<self.version>`` inside
    ``root_dir``. Passing a string as the constructor's ``version`` parameter
    (instead of ``None`` or an int) overrides this: the string is used
    verbatim as the subdirectory name.
    """
    # Read the version once: resolving it may involve inspecting the save
    # directory (per the class docs), so avoid evaluating it twice.
    version = self.version
    # Non-string versions get a pseudo-standard "version_<n>" name a la test-tube.
    subdir = version if isinstance(version, str) else f"version_{version}"
    return os.path.join(self.root_dir, subdir)
76+
5077
@property
5178
def experiment(self):
5279
r"""
@@ -61,10 +88,8 @@ def experiment(self):
6188
if self._experiment is not None:
6289
return self._experiment
6390

64-
root_dir = os.path.join(self.save_dir, self.name)
65-
os.makedirs(root_dir, exist_ok=True)
66-
log_dir = os.path.join(root_dir, "version_" + str(self.version))
67-
self._experiment = SummaryWriter(log_dir=log_dir, **self.kwargs)
91+
os.makedirs(self.root_dir, exist_ok=True)
92+
self._experiment = SummaryWriter(log_dir=self.log_dir, **self.kwargs)
6893
return self._experiment
6994

7095
@rank_zero_only
@@ -108,8 +133,7 @@ def save(self):
108133
# you are using PT version (<v1.2) which does not have implemented flush
109134
self.experiment._get_file_writer().flush()
110135

111-
# create a preudo standard path ala test-tube
112-
dir_path = os.path.join(self.save_dir, self.name, 'version_%s' % self.version)
136+
dir_path = self.log_dir
113137
if not os.path.isdir(dir_path):
114138
dir_path = self.save_dir
115139

tests/test_logging.py

+13
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,19 @@ def test_tensorboard_manual_versioning(tmpdir):
294294
assert logger.version == 1
295295

296296

297+
def test_tensorboard_named_version(tmpdir):
    """Check that a string version, e.g. '2020-02-05-162402', is kept as the logger's version."""
    named_version = "2020-02-05-162402"
    tmpdir.mkdir("tb_versioning")

    tb_logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=named_version)
    # Logging hyperparameters forces data to be written.
    tb_logger.log_hyperparams({"a": 1, "b": 2})

    assert tb_logger.version == named_version
    # Could also test existence of the directory but this fails in the "minimum requirements" test setup
308+
309+
297310
@pytest.mark.parametrize("step_idx", [10, None])
298311
def test_tensorboard_log_metrics(tmpdir, step_idx):
299312
logger = TensorBoardLogger(tmpdir)

0 commit comments

Comments
 (0)