Skip to content

Commit 9a501ff

Browse files
committed
Fix NeptuneLogger to work in ddp mode
1 parent 3a64260 commit 9a501ff

File tree

2 files changed

+76
-46
lines changed

2 files changed

+76
-46
lines changed

pytorch_lightning/loggers/neptune.py

+39-24
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class NeptuneLogger(LightningLoggerBase):
3232
3333
The Neptune logger can be used in the online mode or offline (silent) mode.
3434
To log experiment data in online mode, :class:`NeptuneLogger` requires an API key.
35-
In offline mode, Neptune will log to a local directory.
35+
In offline mode, the logger does not connect to Neptune.
3636
3737
**ONLINE MODE**
3838
@@ -83,7 +83,7 @@ class NeptuneLogger(LightningLoggerBase):
8383
... self.logger.experiment.log_artifact('model_checkpoint.pt', ...)
8484
... self.logger.experiment.whatever_neptune_supports(...)
8585
86-
If you want to log objects after the training is finished use ``close_after_train=False``:
86+
If you want to log objects after the training is finished use ``close_after_fit=False``:
8787
8888
.. code-block:: python
8989
@@ -135,7 +135,7 @@ class NeptuneLogger(LightningLoggerBase):
135135
"namespace/project_name" for example "tom/mnist-classification".
136136
If ``None``, the value of `NEPTUNE_PROJECT` environment variable will be taken.
137137
You need to create the project in https://neptune.ai first.
138-
offline_mode: Optional default False. If ``True`` no logs will be sent
138+
offline_mode: Optional default ``False``. If ``True`` no logs will be sent
139139
to Neptune. Usually used for debug purposes.
140140
close_after_fit: Optional default ``True``. If ``False`` the experiment
141141
will not be closed after training and additional metrics,
@@ -167,6 +167,7 @@ class NeptuneLogger(LightningLoggerBase):
167167
Tags are displayed in the experiment’s Details section and can be viewed
168168
in the experiments view as a column.
169169
"""
170+
170171
def __init__(self,
171172
api_key: Optional[str] = None,
172173
project_name: Optional[str] = None,
@@ -188,24 +189,20 @@ def __init__(self,
188189
self.params = params
189190
self.properties = properties
190191
self.tags = tags
191-
self._experiment = None
192192
self._kwargs = kwargs
193+
self._experiment_id = None
194+
self._experiment = self._create_or_get_experiment()
193195

194-
if offline_mode:
195-
self.mode = 'offline'
196-
neptune.init(project_qualified_name='dry-run/project',
197-
backend=neptune.OfflineBackend())
198-
else:
199-
self.mode = 'online'
200-
neptune.init(api_token=self.api_key,
201-
project_qualified_name=self.project_name)
202-
203-
log.info(f'NeptuneLogger was initialized in {self.mode} mode')
196+
log.info(f'NeptuneLogger will work in {"offline" if self.offline_mode else "online"} mode')
204197

205198
def __getstate__(self):
206199
state = self.__dict__.copy()
207-
# cannot be pickled
200+
201+
# Experiment cannot be pickled, and additionally its ID cannot be pickled in offline mode
208202
state['_experiment'] = None
203+
if self.offline_mode:
204+
state['_experiment_id'] = None
205+
209206
return state
210207

211208
@property
@@ -220,14 +217,11 @@ def experiment(self) -> Experiment:
220217
221218
"""
222219

220+
# Note that even though we initialize self._experiment in __init__,
221+
# it may still end up being None after being pickled and un-pickled
223222
if self._experiment is None:
224-
self._experiment = neptune.create_experiment(
225-
name=self.experiment_name,
226-
params=self.params,
227-
properties=self.properties,
228-
tags=self.tags,
229-
upload_source_files=self.upload_source_files,
230-
**self._kwargs)
223+
self._experiment = self._create_or_get_experiment()
224+
231225
return self._experiment
232226

233227
@rank_zero_only
@@ -261,14 +255,14 @@ def finalize(self, status: str) -> None:
261255

262256
@property
263257
def name(self) -> str:
264-
if self.mode == 'offline':
258+
if self.offline_mode:
265259
return 'offline-name'
266260
else:
267261
return self.experiment.name
268262

269263
@property
270264
def version(self) -> str:
271-
if self.mode == 'offline':
265+
if self.offline_mode:
272266
return 'offline-id-1234'
273267
else:
274268
return self.experiment.id
@@ -363,3 +357,24 @@ def append_tags(self, tags: Union[str, Iterable[str]]) -> None:
363357
if str(tags) == tags:
364358
tags = [tags] # make it an iterable if it is not one yet
365359
self.experiment.append_tags(*tags)
360+
361+
def _create_or_get_experiment(self):
362+
if self.offline_mode:
363+
project = neptune.Session(backend=neptune.OfflineBackend()).get_project('dry-run/project')
364+
else:
365+
session = neptune.Session.with_default_backend(api_token=self.api_key)
366+
project = session.get_project(self.project_name)
367+
368+
if self._experiment_id is None:
369+
exp = project.create_experiment(
370+
name=self.experiment_name,
371+
params=self.params,
372+
properties=self.properties,
373+
tags=self.tags,
374+
upload_source_files=self.upload_source_files,
375+
**self._kwargs)
376+
else:
377+
exp = project.get_experiments(id=self._experiment_id)[0]
378+
379+
self._experiment_id = exp.id
380+
return exp

tests/loggers/test_neptune.py

+37-22
Original file line numberDiff line numberDiff line change
@@ -10,53 +10,68 @@
1010

1111
@patch('pytorch_lightning.loggers.neptune.neptune')
1212
def test_neptune_online(neptune):
13-
logger = NeptuneLogger(api_key='test', offline_mode=False, project_name='project')
14-
neptune.init.assert_called_once_with(api_token='test', project_qualified_name='project')
13+
logger = NeptuneLogger(api_key='test', project_name='project')
1514

16-
assert logger.name == neptune.create_experiment().name
17-
assert logger.version == neptune.create_experiment().id
15+
created_experiment = neptune.Session.with_default_backend().get_project().create_experiment()
16+
17+
# It's important to check if the internal variable _experiment was initialized in __init__.
18+
# Calling logger.experiment would cause a side-effect of initializing _experiment,
19+
# if it wasn't already initialized.
20+
assert logger._experiment == created_experiment
21+
assert logger.name == created_experiment.name
22+
assert logger.version == created_experiment.id
1823

1924

2025
@patch('pytorch_lightning.loggers.neptune.neptune')
21-
def test_neptune_additional_methods(neptune):
26+
def test_neptune_offline(neptune):
2227
logger = NeptuneLogger(offline_mode=True)
2328

29+
neptune.Session.assert_called_once_with(backend=neptune.OfflineBackend())
30+
assert logger.experiment == neptune.Session().get_project().create_experiment()
31+
32+
33+
@patch('pytorch_lightning.loggers.neptune.neptune')
34+
def test_neptune_additional_methods(neptune):
35+
logger = NeptuneLogger(api_key='test', project_name='project')
36+
37+
created_experiment = neptune.Session.with_default_backend().get_project().create_experiment()
38+
2439
logger.log_metric('test', torch.ones(1))
25-
neptune.create_experiment().log_metric.assert_called_once_with('test', torch.ones(1))
26-
neptune.create_experiment().log_metric.reset_mock()
40+
created_experiment.log_metric.assert_called_once_with('test', torch.ones(1))
41+
created_experiment.log_metric.reset_mock()
2742

2843
logger.log_metric('test', 1.0)
29-
neptune.create_experiment().log_metric.assert_called_once_with('test', 1.0)
30-
neptune.create_experiment().log_metric.reset_mock()
44+
created_experiment.log_metric.assert_called_once_with('test', 1.0)
45+
created_experiment.log_metric.reset_mock()
3146

3247
logger.log_metric('test', 1.0, step=2)
33-
neptune.create_experiment().log_metric.assert_called_once_with('test', x=2, y=1.0)
34-
neptune.create_experiment().log_metric.reset_mock()
48+
created_experiment.log_metric.assert_called_once_with('test', x=2, y=1.0)
49+
created_experiment.log_metric.reset_mock()
3550

3651
logger.log_text('test', 'text')
37-
neptune.create_experiment().log_metric.assert_called_once_with('test', 'text')
38-
neptune.create_experiment().log_metric.reset_mock()
52+
created_experiment.log_metric.assert_called_once_with('test', 'text')
53+
created_experiment.log_metric.reset_mock()
3954

4055
logger.log_image('test', 'image file')
41-
neptune.create_experiment().log_image.assert_called_once_with('test', 'image file')
42-
neptune.create_experiment().log_image.reset_mock()
56+
created_experiment.log_image.assert_called_once_with('test', 'image file')
57+
created_experiment.log_image.reset_mock()
4358

4459
logger.log_image('test', 'image file', step=2)
45-
neptune.create_experiment().log_image.assert_called_once_with('test', x=2, y='image file')
46-
neptune.create_experiment().log_image.reset_mock()
60+
created_experiment.log_image.assert_called_once_with('test', x=2, y='image file')
61+
created_experiment.log_image.reset_mock()
4762

4863
logger.log_artifact('file')
49-
neptune.create_experiment().log_artifact.assert_called_once_with('file', None)
64+
created_experiment.log_artifact.assert_called_once_with('file', None)
5065

5166
logger.set_property('property', 10)
52-
neptune.create_experiment().set_property.assert_called_once_with('property', 10)
67+
created_experiment.set_property.assert_called_once_with('property', 10)
5368

5469
logger.append_tags('one tag')
55-
neptune.create_experiment().append_tags.assert_called_once_with('one tag')
56-
neptune.create_experiment().append_tags.reset_mock()
70+
created_experiment.append_tags.assert_called_once_with('one tag')
71+
created_experiment.append_tags.reset_mock()
5772

5873
logger.append_tags(['two', 'tags'])
59-
neptune.create_experiment().append_tags.assert_called_once_with('two', 'tags')
74+
created_experiment.append_tags.assert_called_once_with('two', 'tags')
6075

6176

6277
def test_neptune_leave_open_experiment_after_fit(tmpdir):

0 commit comments

Comments
 (0)