Skip to content

Commit 1a185ca

Browse files
ethanwharris authored and tullie committed
Logger tests and fixes (Lightning-AI#1009)
* Refactor logger tests
* Update and add tests for wandb logger
* Update and add tests for logger bases
* Update and add tests for mlflow logger
* Improve coverage
* Updates
* Update CHANGELOG
* Updates
* Fix style
* Fix style
* Updates
1 parent 18b5172 commit 1a185ca

17 files changed

+767
-514
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4747
- Fixed a bug where the model checkpointer didn't write to the same directory as the logger ([#771](https://github.com/PyTorchLightning/pytorch-lightning/pull/771))
4848
- Fixed a bug where the `TensorBoardLogger` class would create an additional empty log file during fitting ([#777](https://github.com/PyTorchLightning/pytorch-lightning/pull/777))
4949
- Fixed a bug where `global_step` was advanced incorrectly when using `accumulate_grad_batches > 1` ([#832](https://github.com/PyTorchLightning/pytorch-lightning/pull/832))
50+
- Fixed a bug when calling `self.logger.experiment` with multiple loggers ([#1009](https://github.com/PyTorchLightning/pytorch-lightning/pull/1009))
51+
- Fixed a bug when calling `logger.append_tags` on a `NeptuneLogger` with a single tag ([#1009](https://github.com/PyTorchLightning/pytorch-lightning/pull/1009))
5052

5153
## [0.6.0] - 2020-01-21
5254

pytorch_lightning/loggers/base.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __getitem__(self, index: int) -> LightningLoggerBase:
105105

106106
@property
107107
def experiment(self) -> List[Any]:
108-
return [logger.experiment() for logger in self._logger_iterable]
108+
return [logger.experiment for logger in self._logger_iterable]
109109

110110
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
111111
[logger.log_metrics(metrics, step) for logger in self._logger_iterable]
@@ -122,11 +122,7 @@ def finalize(self, status: str):
122122
def close(self):
123123
[logger.close() for logger in self._logger_iterable]
124124

125-
@property
126-
def rank(self) -> int:
127-
return self._rank
128-
129-
@rank.setter
125+
@LightningLoggerBase.rank.setter
130126
def rank(self, value: int):
131127
self._rank = value
132128
for logger in self._logger_iterable:

pytorch_lightning/loggers/comet.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88
import argparse
99
from logging import getLogger
10-
from typing import Optional, Union, Dict
10+
from typing import Optional, Dict, Union
1111

1212
try:
1313
from comet_ml import Experiment as CometExperiment
@@ -20,8 +20,10 @@
2020
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
2121
from comet_ml.papi import API
2222
except ImportError:
23-
raise ImportError('Missing comet_ml package.')
23+
raise ImportError('You want to use `comet_ml` logger which is not installed yet,'
24+
' install it with `pip install comet-ml`.')
2425

26+
import torch
2527
from torch import is_tensor
2628

2729
from pytorch_lightning.utilities.debugging import MisconfigurationException
@@ -87,11 +89,7 @@ def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None
8789
self._experiment = None
8890

8991
# Determine online or offline mode based on which arguments were passed to CometLogger
90-
if save_dir is not None and api_key is not None:
91-
# If arguments are passed for both save_dir and api_key, preference is given to online mode
92-
self.mode = "online"
93-
self.api_key = api_key
94-
elif api_key is not None:
92+
if api_key is not None:
9593
self.mode = "online"
9694
self.api_key = api_key
9795
elif save_dir is not None:
@@ -168,7 +166,11 @@ def log_hyperparams(self, params: argparse.Namespace):
168166
self.experiment.log_parameters(vars(params))
169167

170168
@rank_zero_only
171-
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
169+
def log_metrics(
170+
self,
171+
metrics: Dict[str, Union[torch.Tensor, float]],
172+
step: Optional[int] = None
173+
):
172174
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
173175
for key, val in metrics.items():
174176
if is_tensor(val):

pytorch_lightning/loggers/mlflow.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def any_lightning_module_function_or_hook(...):
3131
try:
3232
import mlflow
3333
except ImportError:
34-
raise ImportError('Missing mlflow package.')
34+
raise ImportError('You want to use `mlflow` logger which is not installed yet,'
35+
' install it with `pip install mlflow`.')
3536

3637
from .base import LightningLoggerBase, rank_zero_only
3738

@@ -79,7 +80,7 @@ def run_id(self):
7980
if expt:
8081
self._expt_id = expt.experiment_id
8182
else:
82-
logger.warning(f"Experiment with name {self.experiment_name} not found. Creating it.")
83+
logger.warning(f'Experiment with name {self.experiment_name} not found. Creating it.')
8384
self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
8485

8586
run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
@@ -96,17 +97,15 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
9697
timestamp_ms = int(time() * 1000)
9798
for k, v in metrics.items():
9899
if isinstance(v, str):
99-
logger.warning(
100-
f"Discarding metric with string value {k}={v}"
101-
)
100+
logger.warning(f'Discarding metric with string value {k}={v}.')
102101
continue
103102
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
104103

105104
def save(self):
106105
pass
107106

108107
@rank_zero_only
109-
def finalize(self, status: str = "FINISHED"):
108+
def finalize(self, status: str = 'FINISHED'):
110109
if status == 'success':
111110
status = 'FINISHED'
112111
self.experiment.set_terminated(self.run_id, status)

pytorch_lightning/loggers/neptune.py

+27-25
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from neptune.experiments import Experiment
1616
except ImportError:
1717
raise ImportError('You want to use `neptune` logger which is not installed yet,'
18-
' please install it e.g. `pip install neptune-client`.')
18+
' install it with `pip install neptune-client`.')
1919

20+
import torch
2021
from torch import is_tensor
2122

22-
# from .base import LightningLoggerBase, rank_zero_only
2323
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
2424

2525
logger = getLogger(__name__)
@@ -130,15 +130,15 @@ def any_lightning_module_function_or_hook(...):
130130
self._kwargs = kwargs
131131

132132
if offline_mode:
133-
self.mode = "offline"
133+
self.mode = 'offline'
134134
neptune.init(project_qualified_name='dry-run/project',
135135
backend=neptune.OfflineBackend())
136136
else:
137-
self.mode = "online"
137+
self.mode = 'online'
138138
neptune.init(api_token=self.api_key,
139139
project_qualified_name=self.project_name)
140140

141-
logger.info(f"NeptuneLogger was initialized in {self.mode} mode")
141+
logger.info(f'NeptuneLogger was initialized in {self.mode} mode')
142142

143143
@property
144144
def experiment(self) -> Experiment:
@@ -166,53 +166,58 @@ def experiment(self) -> Experiment:
166166
@rank_zero_only
167167
def log_hyperparams(self, params: argparse.Namespace):
168168
for key, val in vars(params).items():
169-
self.experiment.set_property(f"param__{key}", val)
169+
self.experiment.set_property(f'param__{key}', val)
170170

171171
@rank_zero_only
172-
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
172+
def log_metrics(
173+
self,
174+
metrics: Dict[str, Union[torch.Tensor, float]],
175+
step: Optional[int] = None
176+
):
173177
"""Log metrics (numeric values) in Neptune experiments
174178
175179
Args:
176180
metrics: Dictionary with metric names as keys and measured quantities as values
177181
step: Step number at which the metrics should be recorded, must be strictly increasing
178182
"""
179-
180183
for key, val in metrics.items():
181-
if is_tensor(val):
182-
val = val.cpu().detach()
183-
184-
if step is None:
185-
self.experiment.log_metric(key, val)
186-
else:
187-
self.experiment.log_metric(key, x=step, y=val)
184+
self.log_metric(key, val, step=step)
188185

189186
@rank_zero_only
190187
def finalize(self, status: str):
191188
self.experiment.stop()
192189

193190
@property
194191
def name(self) -> str:
195-
if self.mode == "offline":
196-
return "offline-name"
192+
if self.mode == 'offline':
193+
return 'offline-name'
197194
else:
198195
return self.experiment.name
199196

200197
@property
201198
def version(self) -> str:
202-
if self.mode == "offline":
203-
return "offline-id-1234"
199+
if self.mode == 'offline':
200+
return 'offline-id-1234'
204201
else:
205202
return self.experiment.id
206203

207204
@rank_zero_only
208-
def log_metric(self, metric_name: str, metric_value: float, step: Optional[int] = None):
205+
def log_metric(
206+
self,
207+
metric_name: str,
208+
metric_value: Union[torch.Tensor, float, str],
209+
step: Optional[int] = None
210+
):
209211
"""Log metrics (numeric values) in Neptune experiments
210212
211213
Args:
212214
metric_name: The name of log, i.e. mse, loss, accuracy.
213215
metric_value: The value of the log (data-point).
214216
step: Step number at which the metrics should be recorded, must be strictly increasing
215217
"""
218+
if is_tensor(metric_value):
219+
metric_value = metric_value.cpu().detach()
220+
216221
if step is None:
217222
self.experiment.log_metric(metric_name, metric_value)
218223
else:
@@ -227,10 +232,7 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None):
227232
text: The value of the log (data-point).
228233
step: Step number at which the metrics should be recorded, must be strictly increasing
229234
"""
230-
if step is None:
231-
self.experiment.log_metric(log_name, text)
232-
else:
233-
self.experiment.log_metric(log_name, x=step, y=text)
235+
self.log_metric(log_name, text, step=step)
234236

235237
@rank_zero_only
236238
def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None):
@@ -277,6 +279,6 @@ def append_tags(self, tags: Union[str, Iterable[str]]):
277279
If multiple - comma separated - str are passed, all of them are added as tags.
278280
If list of str is passed, all elements of the list are added as tags.
279281
"""
280-
if not isinstance(tags, Iterable):
282+
if str(tags) == tags:
281283
tags = [tags] # make it as an iterable is if it is not yet
282284
self.experiment.append_tags(*tags)

pytorch_lightning/loggers/tensorboard.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ class TensorBoardLogger(LightningLoggerBase):
4444
"""
4545
NAME_CSV_TAGS = 'meta_tags.csv'
4646

47-
def __init__(self, save_dir: str, name: str = "default", version: Optional[Union[int, str]] = None, **kwargs):
47+
def __init__(
48+
self, save_dir: str, name: Optional[str] = "default",
49+
version: Optional[Union[int, str]] = None, **kwargs
50+
):
4851
super().__init__()
4952
self.save_dir = save_dir
5053
self._name = name

pytorch_lightning/loggers/test_tube.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
try:
55
from test_tube import Experiment
66
except ImportError:
7-
raise ImportError('Missing test-tube package.')
7+
raise ImportError('You want to use `test_tube` logger which is not installed yet,'
8+
' install it with `pip install test-tube`.')
89

910
from .base import LightningLoggerBase, rank_zero_only
1011

pytorch_lightning/loggers/wandb.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
import os
1010
from typing import Optional, List, Dict
1111

12+
import torch.nn as nn
13+
1214
try:
1315
import wandb
1416
from wandb.wandb_run import Run
1517
except ImportError:
1618
raise ImportError('You want to use `wandb` logger which is not installed yet,'
17-
' please install it e.g. `pip install wandb`.')
19+
' install it with `pip install wandb`.')
1820

1921
from .base import LightningLoggerBase, rank_zero_only
2022

@@ -50,7 +52,7 @@ def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None,
5052
super().__init__()
5153
self._name = name
5254
self._save_dir = save_dir
53-
self._anonymous = "allow" if anonymous else None
55+
self._anonymous = 'allow' if anonymous else None
5456
self._id = version or id
5557
self._tags = tags
5658
self._project = project
@@ -79,27 +81,25 @@ def experiment(self) -> Run:
7981
"""
8082
if self._experiment is None:
8183
if self._offline:
82-
os.environ["WANDB_MODE"] = "dryrun"
84+
os.environ['WANDB_MODE'] = 'dryrun'
8385
self._experiment = wandb.init(
8486
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
85-
id=self._id, resume="allow", tags=self._tags, entity=self._entity)
87+
id=self._id, resume='allow', tags=self._tags, entity=self._entity)
8688
return self._experiment
8789

88-
def watch(self, model, log="gradients", log_freq=100):
89-
wandb.watch(model, log, log_freq)
90+
def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
91+
wandb.watch(model, log=log, log_freq=log_freq)
9092

9193
@rank_zero_only
9294
def log_hyperparams(self, params: argparse.Namespace):
9395
self.experiment.config.update(params)
9496

9597
@rank_zero_only
9698
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
97-
metrics["global_step"] = step
99+
if step is not None:
100+
metrics['global_step'] = step
98101
self.experiment.log(metrics)
99102

100-
def save(self):
101-
pass
102-
103103
@rank_zero_only
104104
def finalize(self, status: str = 'success'):
105105
try:

tests/loggers/__init__.py

Whitespace-only changes.

0 commit comments

Comments (0)