Commit f16b4cf

awaelchli and Borda authored
save_dir fix for MLflowLogger + save_dir tests for others (#2502)
* mlflow rework
* logger save_dir
* folder
* mlflow
* simplify
* fix test
* add a test for file dir contents
* new line
* changelog
* docs
* Update CHANGELOG.md (Co-authored-by: Jirka Borovec <[email protected]>)
* test for comet logger
* improve mlflow checkpoint test
* prevent comet logger error on pytest exit
* test tensorboard save dir structure
* wandb save dir test
* skip test on windows
* add mlflow to pickle tests
* wandb
* code factor
* remove unused imports
* remove unused setter
* wandb mock
* wip mock
* wip mock
* wandb tests with mocking
* clean up
* clean up
* comments
* include wandblogger in test
* clean up
* missing argument

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 992a7e2 commit f16b4cf

13 files changed: +219, -84 lines

CHANGELOG.md (+2)

```diff
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518))
 
+- Fixed a problem with `MLflowLogger` creating multiple run folders ([#2502](https://github.com/PyTorchLightning/pytorch-lightning/pull/2502))
+
 ## [0.8.4] - 2020-07-01
 
 ### Added
```

pytorch_lightning/callbacks/model_checkpoint.py (+3, -6)

```diff
@@ -239,12 +239,9 @@ def on_train_start(self, trainer, pl_module):
 
         if trainer.logger is not None:
             # weights_save_path overrides anything
-            if getattr(trainer, 'weights_save_path', None) is not None:
-                save_dir = trainer.weights_save_path
-            else:
-                save_dir = (getattr(trainer.logger, 'save_dir', None)
-                            or getattr(trainer.logger, '_save_dir', None)
-                            or trainer.default_root_dir)
+            save_dir = (getattr(trainer, 'weights_save_path', None)
+                        or getattr(trainer.logger, 'save_dir', None)
+                        or trainer.default_root_dir)
 
             version = trainer.logger.version if isinstance(
                 trainer.logger.version, str) else f'version_{trainer.logger.version}'
```
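
With `save_dir` promoted to a real property on the logger base class (see `pytorch_lightning/loggers/base.py` below), the callback no longer needs to probe the private `_save_dir` attribute, and the if/else collapses into a single or-chain. A minimal sketch of the resulting resolution order; `FakeTrainer` and `FakeLogger` are illustrative stand-ins, not Lightning classes:

```python
class FakeLogger:
    save_dir = None  # mirrors the new LightningLoggerBase.save_dir default

class FakeTrainer:
    weights_save_path = None             # highest priority when set
    logger = FakeLogger()                # its save_dir is consulted next
    default_root_dir = '/tmp/lightning'  # final fallback

trainer = FakeTrainer()
save_dir = (getattr(trainer, 'weights_save_path', None)
            or getattr(trainer.logger, 'save_dir', None)
            or trainer.default_root_dir)
assert save_dir == '/tmp/lightning'
```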

pytorch_lightning/loggers/base.py (+8)

```diff
@@ -237,6 +237,14 @@ def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
         self.save()
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        """
+        Return the root directory where experiment logs get saved, or `None` if the logger does not
+        save data locally.
+        """
+        return None
+
     @property
     @abstractmethod
     def name(self) -> str:
```
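
The new property establishes a contract: every logger can be asked for `save_dir`, and `None` means "nothing is written locally". A short sketch of the two behaviors subclasses pick from; both classes here are hypothetical, not from the codebase:

```python
from typing import Optional

class RemoteOnlyLogger:
    """Hypothetical logger that only streams to a server; keeps the None default."""
    @property
    def save_dir(self) -> Optional[str]:
        return None

class FileBackedLogger:
    """Hypothetical logger that writes locally and overrides the property."""
    def __init__(self, save_dir: str = './logs'):
        self._save_dir = save_dir

    @property
    def save_dir(self) -> Optional[str]:
        return self._save_dir

# Callers such as ModelCheckpoint can now treat both kinds uniformly:
for logger in (RemoteOnlyLogger(), FileBackedLogger('./experiments')):
    print(logger.save_dir or '/tmp/default_root_dir')
```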

pytorch_lightning/loggers/comet.py (+1, -8)

```diff
@@ -134,10 +134,7 @@ def __init__(self,
         self.comet_api = None
 
         if experiment_name:
-            try:
-                self.name = experiment_name
-            except TypeError:
-                log.exception("Failed to set experiment name for comet.ml logger")
+            self.experiment.set_name(experiment_name)
         self._kwargs = kwargs
 
     @property
@@ -228,10 +225,6 @@ def save_dir(self) -> Optional[str]:
     def name(self) -> str:
         return str(self.experiment.project_name)
 
-    @name.setter
-    def name(self, value: str) -> None:
-        self.experiment.set_name(value)
-
     @property
     def version(self) -> str:
         return self.experiment.id
```

pytorch_lightning/loggers/mlflow.py (+54, -27)

```diff
@@ -2,7 +2,6 @@
 MLflow
 ------
 """
-import os
 from argparse import Namespace
 from time import time
 from typing import Optional, Dict, Any, Union
@@ -11,16 +10,20 @@
     import mlflow
     from mlflow.tracking import MlflowClient
     _MLFLOW_AVAILABLE = True
-except ImportError:  # pragma: no-cover
+except ModuleNotFoundError:  # pragma: no-cover
     mlflow = None
     MlflowClient = None
     _MLFLOW_AVAILABLE = False
 
+
 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
 
 
+LOCAL_FILE_URI_PREFIX = "file:"
+
+
 class MLFlowLogger(LightningLoggerBase):
     """
     Log using `MLflow <https://mlflow.org>`_. Install it with pip:
@@ -52,59 +55,71 @@ class MLFlowLogger(LightningLoggerBase):
     Args:
         experiment_name: The name of the experiment
         tracking_uri: Address of local or remote tracking server.
-            If not provided, defaults to the service set by ``mlflow.tracking.set_tracking_uri``.
+            If not provided, defaults to `file:<save_dir>`.
         tags: A dictionary tags for the experiment.
+        save_dir: A path to a local directory where the MLflow runs get saved.
+            Defaults to `./mlruns` if `tracking_uri` is not provided.
+            Has no effect if `tracking_uri` is provided.
 
     """
 
     def __init__(self,
                  experiment_name: str = 'default',
                  tracking_uri: Optional[str] = None,
                  tags: Optional[Dict[str, Any]] = None,
-                 save_dir: Optional[str] = None):
+                 save_dir: Optional[str] = './mlruns'):
 
         if not _MLFLOW_AVAILABLE:
             raise ImportError('You want to use `mlflow` logger which is not installed yet,'
                               ' install it with `pip install mlflow`.')
         super().__init__()
-        if not tracking_uri and save_dir:
-            tracking_uri = f'file:{os.sep * 2}{save_dir}'
-        self._mlflow_client = MlflowClient(tracking_uri)
-        self.experiment_name = experiment_name
+        if not tracking_uri:
+            tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}'
+
+        self._experiment_name = experiment_name
+        self._experiment_id = None
+        self._tracking_uri = tracking_uri
         self._run_id = None
         self.tags = tags
+        self._mlflow_client = MlflowClient(tracking_uri)
 
     @property
     @rank_zero_experiment
    def experiment(self) -> MlflowClient:
         r"""
-        Actual MLflow object. To use mlflow features in your
+        Actual MLflow object. To use MLflow features in your
         :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
 
         Example::
 
             self.logger.experiment.some_mlflow_function()
 
         """
-        return self._mlflow_client
-
-    @property
-    def run_id(self):
-        if self._run_id is not None:
-            return self._run_id
-
-        expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)
+        expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
 
         if expt:
-            self._expt_id = expt.experiment_id
+            self._experiment_id = expt.experiment_id
         else:
-            log.warning(f'Experiment with name {self.experiment_name} not found. Creating it.')
-            self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
+            log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.')
+            self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name)
 
-        run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
-        self._run_id = run.info.run_id
+        if not self._run_id:
+            run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)
+            self._run_id = run.info.run_id
+        return self._mlflow_client
+
+    @property
+    def run_id(self):
+        # create the experiment if it does not exist to get the run id
+        _ = self.experiment
         return self._run_id
 
+    @property
+    def experiment_id(self):
+        # create the experiment if it does not exist to get the experiment id
+        _ = self.experiment
+        return self._experiment_id
+
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = self._convert_params(params)
@@ -126,14 +141,26 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     @rank_zero_only
     def finalize(self, status: str = 'FINISHED') -> None:
         super().finalize(status)
-        if status == 'success':
-            status = 'FINISHED'
-        self.experiment.set_terminated(self.run_id, status)
+        status = 'FINISHED' if status == 'success' else status
+        if self.experiment.get_run(self.run_id):
+            self.experiment.set_terminated(self.run_id, status)
+
+    @property
+    def save_dir(self) -> Optional[str]:
+        """
+        The root file directory in which MLflow experiments are saved.
+
+        Return:
+            Local path to the root experiment directory if the tracking uri is local.
+            Otherwise returns `None`.
+        """
+        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
+            return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)
 
     @property
     def name(self) -> str:
-        return self.experiment_name
+        return self.experiment_id
 
     @property
     def version(self) -> str:
-        return self._run_id
+        return self.run_id
```
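
Two behaviors in this diff are worth a closer look. First, run creation is now guarded by `if not self._run_id`, so repeated accesses to `experiment`, `run_id`, or `version` reuse one cached run instead of creating a fresh run folder each time, which is the bug named in the CHANGELOG. Second, `save_dir` recovers the local path by stripping the `file:` scheme with `str.lstrip`; since `lstrip` removes a *character set* rather than a literal prefix, this is harmless for paths such as `./mlruns` or `/abs/path` but is a latent edge case for paths beginning with one of the letters in `file:`. A small sketch with plain strings (no MLflow required; `strip_prefix` is a hypothetical helper, not part of the codebase):

```python
LOCAL_FILE_URI_PREFIX = "file:"

# The default tracking URI built in __init__ when none is given:
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}./mlruns"

# What the new save_dir property computes:
assert tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX) == "./mlruns"

# Caveat: lstrip() strips any leading characters from the set {f, i, l, e, :},
# not the literal prefix, so a bare directory name can get mangled:
assert "file:lightning_logs".lstrip(LOCAL_FILE_URI_PREFIX) == "ghtning_logs"

# A prefix slice avoids the edge case:
def strip_prefix(uri: str, prefix: str = LOCAL_FILE_URI_PREFIX) -> str:
    return uri[len(prefix):] if uri.startswith(prefix) else uri

assert strip_prefix("file:lightning_logs") == "lightning_logs"
```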

pytorch_lightning/loggers/tensorboard.py (+1, -5)

```diff
@@ -51,7 +51,7 @@ def __init__(self,
                  **kwargs):
         super().__init__()
         self._save_dir = save_dir
-        self._name = name
+        self._name = name or ''
         self._version = version
 
         self._experiment = None
@@ -106,10 +106,6 @@ def experiment(self) -> SummaryWriter:
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
-    @experiment.setter
-    def experiment(self, exp):
-        self._experiment = exp
-
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
                         metrics: Optional[Dict[str, Any]] = None) -> None:
```

pytorch_lightning/loggers/test_tube.py (+5, -1)

```diff
@@ -68,7 +68,7 @@ def __init__(self,
             raise ImportError('You want to use `test_tube` logger which is not installed yet,'
                               ' install it with `pip install test-tube`.')
         super().__init__()
-        self.save_dir = save_dir
+        self._save_dir = save_dir
         self._name = name
         self.description = description
         self.debug = debug
@@ -141,6 +141,10 @@ def close(self) -> None:
         exp = self.experiment
         exp.close()
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def name(self) -> str:
         if self._experiment is None:
```

pytorch_lightning/loggers/wandb.py (+7, -4)

```diff
@@ -116,7 +116,7 @@ def experiment(self) -> Run:
                 group=self._group)
             # save checkpoints in wandb dir to upload on W&B servers
             if self._log_model:
-                self.save_dir = self._experiment.dir
+                self._save_dir = self._experiment.dir
         return self._experiment
 
     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -134,13 +134,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
 
         self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def name(self) -> Optional[str]:
         # don't create an experiment if we don't have one
-        name = self._experiment.project_name() if self._experiment else None
-        return name
+        return self._experiment.project_name() if self._experiment else self._name
 
     @property
     def version(self) -> Optional[str]:
         # don't create an experiment if we don't have one
-        return self._experiment.id if self._experiment else None
+        return self._experiment.id if self._experiment else self._id
```
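
The fallbacks to `self._name` and `self._id` matter for pickling: an unpickled logger (for example on a rank > 0 worker) has no live `_experiment`, and `name`/`version` must be answerable without spawning a run. A minimal sketch of the pattern; `LazyExperimentLogger` is an illustrative stand-in, not the actual implementation:

```python
from typing import Optional

class LazyExperimentLogger:
    def __init__(self, name: Optional[str] = None, id: Optional[str] = None):
        self._name, self._id = name, id
        self._experiment = None  # created lazily; not carried through pickling

    @property
    def name(self) -> Optional[str]:
        # fall back to the constructor argument instead of creating a run
        return self._experiment.project_name() if self._experiment else self._name

    @property
    def version(self) -> Optional[str]:
        return self._experiment.id if self._experiment else self._id

logger = LazyExperimentLogger(name='demo', id='abc123')
assert (logger.name, logger.version) == ('demo', 'abc123')  # no experiment created
```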

tests/loggers/test_all.py (+28, -15)

```diff
@@ -1,6 +1,8 @@
+import atexit
 import inspect
 import pickle
 import platform
+from unittest import mock
 
 import pytest
 
@@ -35,14 +37,15 @@ def _get_logger_args(logger_class, save_dir):
     MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
-    # WandbLogger,  # TODO: add this one
+    WandbLogger,
 ])
-def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class):
     """Verify the basic functionality of all loggers."""
-    # prevent comet logger from trying to print at exit, since
-    # pytest's stdout/stderr redirection breaks it
-    import atexit
-    monkeypatch.setattr(atexit, 'register', lambda _: None)
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
 
     model = EvalModelTemplate()
 
@@ -58,6 +61,11 @@ def log_metrics(self, metrics, step):
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = StoreHistoryLogger(**logger_args)
 
+    if logger_class == WandbLogger:
+        # required mocks for Trainer
+        logger.experiment.id = 'foo'
+        logger.experiment.project_name.return_value = 'bar'
+
     trainer = Trainer(
         max_epochs=1,
         logger=logger,
@@ -66,7 +74,6 @@ def log_metrics(self, metrics, step):
         fast_dev_run=True,
     )
     trainer.fit(model)
-
     trainer.test()
 
     log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
@@ -78,17 +85,17 @@ def log_metrics(self, metrics, step):
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
-    # MLFlowLogger,
+    MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
-    # WandbLogger,  # TODO: add this one
+    # The WandbLogger gets tested for pickling in its own test.
 ])
 def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
     """Verify that pickling trainer with logger works."""
-    # prevent comet logger from trying to print at exit, since
-    # pytest's stdout/stderr redirection breaks it
-    import atexit
-    monkeypatch.setattr(atexit, 'register', lambda _: None)
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
 
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
@@ -156,12 +163,18 @@ def on_batch_start(self, trainer, pl_module):
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
-    # MLFlowLogger,
+    MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
     WandbLogger,
 ])
-def test_logger_created_on_rank_zero_only(tmpdir, logger_class):
+def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
+    """Test that loggers get replaced by dummy loggers on global rank > 0."""
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
+
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
     model = EvalModelTemplate()
```
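
The key move in these tests is `mock.patch('pytorch_lightning.loggers.wandb.wandb')`: it swaps out the `wandb` module object that the logger module imported, so `wandb.init(...)` returns a `MagicMock` whose attributes (like the `experiment.id` assigned above) can be set freely and no real W&B run is started. A rough standalone sketch of the same idea, assuming a pytorch-lightning of this vintage with the `wandb` package importable:

```python
from unittest import mock

# Patch the name where it is *used* (inside the logger module), not the wandb
# package itself; the logger module holds its own reference from `import wandb`.
with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
    from pytorch_lightning.loggers import WandbLogger

    logger = WandbLogger(name='demo')
    run = logger.experiment            # invokes the mocked wandb.init(...)
    run.id = 'foo'                     # a MagicMock accepts arbitrary attributes
    run.project_name.return_value = 'bar'

    wandb.init.assert_called_once()    # no real W&B run was created
    assert logger.version == 'foo'     # version falls through to the mocked run
    assert logger.name == 'bar'
```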
