Skip to content

Commit b3fe17d

Browse files
authored
fix flushing loggers (#1459)
* flushing loggers * flushing loggers * flushing loggers * flushing loggers * changelog * typo * fix trains * optimize imports * add logger test all * add logger test pickle * flake8 * fix benchmark * hanging loggers * try * del * all * cleaning
1 parent c96c6a6 commit b3fe17d

21 files changed

+209
-334
lines changed

.github/workflows/ci-testing.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
requires: 'minimal'
2929

3030
# Timeout: https://stackoverflow.com/a/59076067/4521646
31-
timeout-minutes: 30
31+
timeout-minutes: 15
3232
steps:
3333
- uses: actions/checkout@v2
3434
- name: Set up Python ${{ matrix.python-version }}

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3434

3535
### Fixed
3636

37-
-
37+
- Fixed loggers - flushing last logged metrics even before continuing, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
3838

3939
-
4040

pytorch_lightning/loggers/base.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
`LightningLoggerBase.agg_and_log_metrics` method.
4949
"""
5050
self._rank = 0
51-
self._prev_step = -1
51+
self._prev_step: int = -1
5252
self._metrics_to_agg: List[Dict[str, float]] = []
5353
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
5454
self._agg_default_func = agg_default_func
@@ -98,15 +98,15 @@ def _aggregate_metrics(
9898
return step, None
9999

100100
# compute the metrics
101-
agg_step, agg_mets = self._finalize_agg_metrics()
101+
agg_step, agg_mets = self._reduce_agg_metrics()
102102

103103
# as new step received reset accumulator
104104
self._metrics_to_agg = [metrics]
105105
self._prev_step = step
106106
return agg_step, agg_mets
107107

108-
def _finalize_agg_metrics(self):
109-
"""Aggregate accumulated metrics. This shall be called in close."""
108+
def _reduce_agg_metrics(self):
109+
"""Aggregate accumulated metrics."""
110110
# compute the metrics
111111
if not self._metrics_to_agg:
112112
agg_mets = None
@@ -116,6 +116,14 @@ def _finalize_agg_metrics(self):
116116
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
117117
return self._prev_step, agg_mets
118118

119+
def _finalize_agg_metrics(self):
120+
"""This shall be called before save/close."""
121+
agg_step, metrics_to_log = self._reduce_agg_metrics()
122+
self._metrics_to_agg = []
123+
124+
if metrics_to_log is not None:
125+
self.log_metrics(metrics=metrics_to_log, step=agg_step)
126+
119127
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
120128
"""Aggregates and records metrics.
121129
This method doesn't log the passed metrics instantaneously, but instead
@@ -219,22 +227,19 @@ def log_hyperparams(self, params: argparse.Namespace):
219227

220228
def save(self) -> None:
221229
"""Save log data."""
222-
pass
230+
self._finalize_agg_metrics()
223231

224232
def finalize(self, status: str) -> None:
225233
"""Do any processing that is necessary to finalize an experiment.
226234
227235
Args:
228236
status: Status that the experiment finished with (e.g. success, failed, aborted)
229237
"""
230-
pass
238+
self.save()
231239

232240
def close(self) -> None:
233241
"""Do any cleanup that is necessary to close an experiment."""
234-
agg_step, metrics_to_log = self._finalize_agg_metrics()
235-
236-
if metrics_to_log is not None:
237-
self.log_metrics(metrics=metrics_to_log, step=agg_step)
242+
self.save()
238243

239244
@property
240245
def rank(self) -> int:
@@ -292,7 +297,6 @@ def close(self) -> None:
292297

293298
@LightningLoggerBase.rank.setter
294299
def rank(self, value: int) -> None:
295-
self._rank = value
296300
for logger in self._logger_iterable:
297301
logger.rank = value
298302

pytorch_lightning/loggers/comet.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,15 @@ class CometLogger(LightningLoggerBase):
3636
Log using `comet.ml <https://www.comet.ml>`_.
3737
"""
3838

39-
def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None,
40-
workspace: Optional[str] = None, project_name: Optional[str] = None,
41-
rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None,
42-
experiment_key: Optional[str] = None, **kwargs):
39+
def __init__(self,
40+
api_key: Optional[str] = None,
41+
save_dir: Optional[str] = None,
42+
workspace: Optional[str] = None,
43+
project_name: Optional[str] = None,
44+
rest_api_key: Optional[str] = None,
45+
experiment_name: Optional[str] = None,
46+
experiment_key: Optional[str] = None,
47+
**kwargs):
4348
r"""
4449
4550
Requires either an API Key (online mode) or a local directory path (offline mode)
@@ -118,6 +123,7 @@ def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None
118123
self.name = experiment_name
119124
except TypeError as e:
120125
log.exception("Failed to set experiment name for comet.ml logger")
126+
self._kwargs = kwargs
121127

122128
@property
123129
def experiment(self) -> CometBaseExperiment:
@@ -197,7 +203,7 @@ def finalize(self, status: str) -> None:
197203

198204
@property
199205
def name(self) -> str:
200-
return self.experiment.project_name
206+
return str(self.experiment.project_name)
201207

202208
@name.setter
203209
def name(self, value: str) -> None:

pytorch_lightning/loggers/mlflow.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def any_lightning_module_function_or_hook(...):
2323
self.logger.experiment.whatever_ml_flow_supports(...)
2424
2525
"""
26+
import os
2627
from argparse import Namespace
2728
from time import time
2829
from typing import Optional, Dict, Any, Union
@@ -39,10 +40,14 @@ def any_lightning_module_function_or_hook(...):
3940

4041

4142
class MLFlowLogger(LightningLoggerBase):
42-
def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None,
43-
tags: Dict[str, Any] = None):
44-
r"""
43+
"""MLFLow logger"""
4544

45+
def __init__(self,
46+
experiment_name: str = 'default',
47+
tracking_uri: Optional[str] = None,
48+
tags: Optional[Dict[str, Any]] = None,
49+
save_dir: Optional[str] = None):
50+
r"""
4651
Logs using MLFlow
4752
4853
Args:
@@ -51,6 +56,8 @@ def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None,
5156
tags (dict): todo this param
5257
"""
5358
super().__init__()
59+
if not tracking_uri and save_dir:
60+
tracking_uri = f'file:{os.sep * 2}{save_dir}'
5461
self._mlflow_client = MlflowClient(tracking_uri)
5562
self.experiment_name = experiment_name
5663
self._run_id = None
@@ -59,7 +66,6 @@ def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None,
5966
@property
6067
def experiment(self) -> MlflowClient:
6168
r"""
62-
6369
Actual mlflow object. To use mlflow features do the following.
6470
6571
Example::
@@ -102,11 +108,9 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
102108
continue
103109
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
104110

105-
def save(self):
106-
pass
107-
108111
@rank_zero_only
109112
def finalize(self, status: str = 'FINISHED') -> None:
113+
super().finalize(status)
110114
if status == 'success':
111115
status = 'FINISHED'
112116
self.experiment.set_terminated(self.run_id, status)

pytorch_lightning/loggers/neptune.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,18 @@ class NeptuneLogger(LightningLoggerBase):
2929
To log experiment data in online mode, NeptuneLogger requires an API key:
3030
"""
3131

32-
def __init__(self, api_key: Optional[str] = None, project_name: Optional[str] = None,
33-
close_after_fit: Optional[bool] = True, offline_mode: bool = False,
32+
def __init__(self,
33+
api_key: Optional[str] = None,
34+
project_name: Optional[str] = None,
35+
close_after_fit: Optional[bool] = True,
36+
offline_mode: bool = True,
3437
experiment_name: Optional[str] = None,
35-
upload_source_files: Optional[List[str]] = None, params: Optional[Dict[str, Any]] = None,
36-
properties: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, **kwargs):
38+
upload_source_files: Optional[List[str]] = None,
39+
params: Optional[Dict[str, Any]] = None,
40+
properties: Optional[Dict[str, Any]] = None,
41+
tags: Optional[List[str]] = None,
42+
**kwargs):
3743
r"""
38-
3944
Initialize a neptune.ai logger.
4045
4146
.. note:: Requires either an API Key (online mode) or a local directory path (offline mode)
@@ -135,8 +140,8 @@ def any_lightning_module_function_or_hook(...):
135140
"namespace/project_name" for example "tom/minst-classification".
136141
If None, the value of NEPTUNE_PROJECT environment variable will be taken.
137142
You need to create the project in https://neptune.ai first.
138-
offline_mode: Optional default False. If offline_mode=True no logs will be sent
139-
to neptune. Usually used for debug purposes.
143+
offline_mode: Optional default True. If offline_mode=True no logs will be sent
144+
to neptune. Usually used for debug and test purposes.
140145
close_after_fit: Optional default True. If close_after_fit=False the experiment
141146
will not be closed after training and additional metrics,
142147
images or artifacts can be logged. Also, remember to close the experiment explicitly
@@ -243,6 +248,7 @@ def log_metrics(
243248

244249
@rank_zero_only
245250
def finalize(self, status: str) -> None:
251+
super().finalize(status)
246252
if self.close_after_fit:
247253
self.experiment.stop()
248254

pytorch_lightning/loggers/tensorboard.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
class TensorBoardLogger(LightningLoggerBase):
1616
r"""
17-
1817
Log to local file system in TensorBoard format
1918
2019
Implemented using :class:`torch.utils.tensorboard.SummaryWriter`. Logs are saved to
@@ -40,18 +39,19 @@ class TensorBoardLogger(LightningLoggerBase):
4039
"""
4140
NAME_CSV_TAGS = 'meta_tags.csv'
4241

43-
def __init__(
44-
self, save_dir: str, name: Optional[str] = "default",
45-
version: Optional[Union[int, str]] = None, **kwargs
46-
):
42+
def __init__(self,
43+
save_dir: str,
44+
name: Optional[str] = "default",
45+
version: Optional[Union[int, str]] = None,
46+
**kwargs):
4747
super().__init__()
4848
self.save_dir = save_dir
4949
self._name = name
5050
self._version = version
5151

5252
self._experiment = None
5353
self.tags = {}
54-
self.kwargs = kwargs
54+
self._kwargs = kwargs
5555

5656
@property
5757
def root_dir(self) -> str:
@@ -92,7 +92,7 @@ def experiment(self) -> SummaryWriter:
9292
return self._experiment
9393

9494
os.makedirs(self.root_dir, exist_ok=True)
95-
self._experiment = SummaryWriter(log_dir=self.log_dir, **self.kwargs)
95+
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
9696
return self._experiment
9797

9898
@rank_zero_only
@@ -127,6 +127,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
127127

128128
@rank_zero_only
129129
def save(self) -> None:
130+
super().save()
130131
try:
131132
self.experiment.flush()
132133
except AttributeError:

pytorch_lightning/loggers/test_tube.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ class TestTubeLogger(LightningLoggerBase):
1818

1919
__test__ = False
2020

21-
def __init__(
22-
self, save_dir: str, name: str = "default", description: Optional[str] = None,
23-
debug: bool = False, version: Optional[int] = None, create_git_tag: bool = False
24-
):
21+
def __init__(self,
22+
save_dir: str,
23+
name: str = "default",
24+
description: Optional[str] = None,
25+
debug: bool = False,
26+
version: Optional[int] = None,
27+
create_git_tag: bool = False):
2528
r"""
2629
2730
Example
@@ -105,19 +108,22 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
105108

106109
@rank_zero_only
107110
def save(self) -> None:
111+
super().save()
108112
# TODO: HACK figure out where this is being set to true
109113
self.experiment.debug = self.debug
110114
self.experiment.save()
111115

112116
@rank_zero_only
113117
def finalize(self, status: str) -> None:
118+
super().finalize(status)
114119
# TODO: HACK figure out where this is being set to true
115120
self.experiment.debug = self.debug
116121
self.save()
117122
self.close()
118123

119124
@rank_zero_only
120125
def close(self) -> None:
126+
super().save()
121127
# TODO: HACK figure out where this is being set to true
122128
self.experiment.debug = self.debug
123129
if not self.debug:

pytorch_lightning/loggers/trains.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,9 @@ def log_artifact(
295295
delete_after_upload=delete_after_upload
296296
)
297297

298-
def save(self) -> None:
299-
pass
300-
301298
@rank_zero_only
302299
def finalize(self, status: str = None) -> None:
300+
# super().finalize(status)
303301
if self.bypass_mode() or not self._trains:
304302
return
305303

pytorch_lightning/loggers/wandb.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,18 @@ class WandbLogger(LightningLoggerBase):
4646
trainer = Trainer(logger=wandb_logger)
4747
"""
4848

49-
def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None,
50-
offline: bool = False, id: Optional[str] = None, anonymous: bool = False,
51-
version: Optional[str] = None, project: Optional[str] = None,
52-
tags: Optional[List[str]] = None, log_model: bool = False,
53-
experiment=None, entity=None):
49+
def __init__(self,
50+
name: Optional[str] = None,
51+
save_dir: Optional[str] = None,
52+
offline: bool = False,
53+
id: Optional[str] = None,
54+
anonymous: bool = False,
55+
version: Optional[str] = None,
56+
project: Optional[str] = None,
57+
tags: Optional[List[str]] = None,
58+
log_model: bool = False,
59+
experiment=None,
60+
entity=None):
5461
super().__init__()
5562
self._name = name
5663
self._save_dir = save_dir

pytorch_lightning/trainer/evaluation_loop.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -370,14 +370,13 @@ def run_evaluation(self, test_mode: bool = False):
370370

371371
# run evaluation
372372
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
373-
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
374-
eval_results)
373+
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
375374

376375
# add metrics to prog bar
377376
self.add_tqdm_metrics(prog_bar_metrics)
378377

379378
# log results of test
380-
if test_mode and self.proc_rank == 0 and len(callback_metrics) > 0:
379+
if test_mode and self.proc_rank == 0:
381380
print('-' * 80)
382381
print('TEST RESULTS')
383382
pprint(callback_metrics)

pytorch_lightning/trainer/trainer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,7 @@ def __init__(
293293

294294
# benchmarking
295295
self.benchmark = benchmark
296-
if benchmark:
297-
torch.backends.cudnn.benchmark = True
296+
torch.backends.cudnn.benchmark = self.benchmark
298297

299298
# Transfer params
300299
self.num_nodes = num_nodes

0 commit comments

Comments
 (0)