@@ -29,9 +29,12 @@ class TensorBoardLogger(LightningLoggerBase):
29
29
30
30
Args:
31
31
save_dir (str): Save directory
32
- name (str): Experiment name. Defaults to "default".
33
- version (int): Experiment version. If version is not specified the logger inspects the save
34
- directory for existing versions, then automatically assigns the next available version.
32
+ name (str): Experiment name. Defaults to "default". If it is the empty string then no per-experiment
33
+ subdirectory is used.
34
+ version (int|str): Experiment version. If version is not specified the logger inspects the save
35
+ directory for existing versions, then automatically assigns the next available version.
36
+ If it is a string then it is used as the run-specific subdirectory name,
37
+ otherwise version_${version} is used.
35
38
\**kwargs (dict): Other arguments are passed directly to the :class:`SummaryWriter` constructor.
36
39
37
40
"""
@@ -47,6 +50,30 @@ def __init__(self, save_dir, name="default", version=None, **kwargs):
47
50
self .tags = {}
48
51
self .kwargs = kwargs
49
52
53
+ @property
54
+ def root_dir (self ):
55
+ """
56
+ Parent directory for all tensorboard checkpoint subdirectories.
57
+ If the experiment name parameter is None or the empty string, no experiment subdirectory is used
58
+ and checkpoint will be saved in save_dir/version_dir
59
+ """
60
+ if self .name is None or len (self .name ) == 0 :
61
+ return self .save_dir
62
+ else :
63
+ return os .path .join (self .save_dir , self .name )
64
+
65
+ @property
66
+ def log_dir (self ):
67
+ """
68
+ The directory for this run's tensorboard checkpoint. By default, it is named 'version_${self.version}'
69
+ but it can be overridden by passing a string value for the constructor's version parameter
70
+ instead of None or an int
71
+ """
72
+ # create a pseudo standard path ala test-tube
73
+ version = self .version if isinstance (self .version , str ) else f"version_{ self .version } "
74
+ log_dir = os .path .join (self .root_dir , version )
75
+ return log_dir
76
+
50
77
@property
51
78
def experiment (self ):
52
79
r"""
@@ -61,10 +88,8 @@ def experiment(self):
61
88
if self ._experiment is not None :
62
89
return self ._experiment
63
90
64
- root_dir = os .path .join (self .save_dir , self .name )
65
- os .makedirs (root_dir , exist_ok = True )
66
- log_dir = os .path .join (root_dir , "version_" + str (self .version ))
67
- self ._experiment = SummaryWriter (log_dir = log_dir , ** self .kwargs )
91
+ os .makedirs (self .root_dir , exist_ok = True )
92
+ self ._experiment = SummaryWriter (log_dir = self .log_dir , ** self .kwargs )
68
93
return self ._experiment
69
94
70
95
@rank_zero_only
@@ -108,8 +133,7 @@ def save(self):
108
133
# you are using PT version (<v1.2) which does not have implemented flush
109
134
self .experiment ._get_file_writer ().flush ()
110
135
111
- # create a preudo standard path ala test-tube
112
- dir_path = os .path .join (self .save_dir , self .name , 'version_%s' % self .version )
136
+ dir_path = self .log_dir
113
137
if not os .path .isdir (dir_path ):
114
138
dir_path = self .save_dir
115
139
0 commit comments