Commit ec7fc97

borisdayma, vanpelt, and williamFalcon committed
Feature: wandb logger (#627)
* Basic wandb support
* refactor(wandb): remove unused variables and document logger
* docs(wandb): explain how to use WandbLogger
* test(wandb): add tests for WandbLogger
* feat(wandb): add save_dir
* fix(wandb): allow pickle of logger
* fix(wandb): save logs in custom directory
* test(wandb): test import
* docs(wandb): simplify docstring and use doctest
* test: increase number of epochs for satisfactory accuracy
* test(test_load_model_from_checkpoint): ensure we load last checkpoint

Co-authored-by: Chris Van Pelt <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent f7db44e commit ec7fc97

File tree

8 files changed, +139 -6 lines changed

.run_local_tests.sh (+1)

@@ -3,6 +3,7 @@ rm -rf _ckpt_*
 rm -rf tests/save_dir*
 rm -rf tests/mlruns_*
 rm -rf tests/cometruns*
+rm -rf tests/wandb*
 rm -rf tests/tests/*
 rm -rf lightning_logs
 coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules

README.md (+1)

@@ -306,6 +306,7 @@ Lightning also adds a text column with all the hyperparameters for this experiment
 - [Save a snapshot of all hyperparameters](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#save-a-snapshot-of-all-hyperparameters)
 - [Snapshot code for a training run](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#snapshot-code-for-a-training-run)
 - [Write logs file to csv every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#write-logs-file-to-csv-every-k-batches)
+- [Logging on W&B](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#w&b)
 - [Logging experiment data to Neptune](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#neptune-support)
 
 #### Training loop

docs/source/conf.py (+1 -1)

@@ -294,7 +294,7 @@ def setup(app):
     MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())
 
 # TODO: better parse from package since the import name and package name may differ
-MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'sklearn', 'test_tube', 'mlflow', 'comet_ml', 'neptune']
+MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'sklearn', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune']
 autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES
 # for mod_name in MOCK_REQUIRE_PACKAGES:
 #     sys.modules[mod_name] = mock.Mock()

pytorch_lightning/logging/__init__.py (+4)

@@ -180,6 +180,10 @@ def __init__(self, hparams):
 except ImportError:
     pass
 
+try:
+    from .wandb import WandbLogger
+except ImportError:
+    pass
 try:
     # needed to prevent ImportError and duplicated logs.
     environ["COMET_DISABLE_AUTO_LOGGING"] = "1"

pytorch_lightning/logging/wandb.py (new file, +108)

@@ -0,0 +1,108 @@
+"""
+Log using `W&B <https://www.wandb.com>`_
+
+.. code-block:: python
+
+    >>> from pytorch_lightning.logging import WandbLogger
+    >>> from pytorch_lightning import Trainer
+    >>> wandb_logger = WandbLogger()
+    >>> trainer = Trainer(logger=wandb_logger)
+
+
+Use the logger anywhere in your LightningModule as follows:
+
+.. code-block:: python
+
+    def train_step(...):
+        # example
+        self.logger.experiment.whatever_wandb_supports(...)
+
+    def any_lightning_module_function_or_hook(...):
+        self.logger.experiment.whatever_wandb_supports(...)
+
+"""
+
+import os
+
+try:
+    import wandb
+except ImportError:
+    raise ImportError('Missing wandb package.')
+
+from .base import LightningLoggerBase, rank_zero_only
+
+
+class WandbLogger(LightningLoggerBase):
+    """
+    Logger for W&B.
+
+    Args:
+        name (str): display name for the run.
+        save_dir (str): path where data is saved.
+        offline (bool): run offline (data can be streamed later to wandb servers).
+        id or version (str): sets the version, mainly used to resume a previous run.
+        anonymous (bool): enables or explicitly disables anonymous logging.
+        project (str): the name of the project to which this run will belong.
+        tags (list of str): tags associated with this run.
+    """
+
+    def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=False,
+                 version=None, project=None, tags=None):
+        super().__init__()
+        self._name = name
+        self._save_dir = save_dir
+        self._anonymous = "allow" if anonymous else None
+        self._id = version or id
+        self._tags = tags
+        self._project = project
+        self._experiment = None
+        self._offline = offline
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # cannot be pickled
+        state['_experiment'] = None
+        # args needed to reload correct experiment
+        state['_id'] = self.experiment.id
+        return state
+
+    @property
+    def experiment(self):
+        if self._experiment is None:
+            if self._offline:
+                os.environ["WANDB_MODE"] = "dryrun"
+            self._experiment = wandb.init(
+                name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
+                id=self._id, resume="allow", tags=self._tags)
+        return self._experiment
+
+    def watch(self, model, log="gradients", log_freq=100):
+        wandb.watch(model, log, log_freq)
+
+    @rank_zero_only
+    def log_hyperparams(self, params):
+        self.experiment.config.update(params)
+
+    @rank_zero_only
+    def log_metrics(self, metrics, step=None):
+        metrics["global_step"] = step
+        self.experiment.history.add(metrics)
+
+    def save(self):
+        pass
+
+    @rank_zero_only
+    def finalize(self, status='success'):
+        try:
+            exit_code = 0 if status == 'success' else 1
+            wandb.join(exit_code)
+        except TypeError:
+            wandb.join()
+
+    @property
+    def name(self):
+        return self.experiment.project_name()
+
+    @property
+    def version(self):
+        return self.experiment.id
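
Taken together, a minimal end-to-end sketch of the API this file adds. MyLightningModule is a hypothetical model name, and offline=True exercises the WANDB_MODE="dryrun" branch of `experiment` above, so nothing is streamed to wandb servers during the run:

from pytorch_lightning import Trainer
from pytorch_lightning.logging import WandbLogger

model = MyLightningModule()  # hypothetical LightningModule

# offline=True takes the WANDB_MODE="dryrun" path in `experiment`,
# keeping the run local; it can be synced to wandb servers later.
wandb_logger = WandbLogger(name='baseline', save_dir='wandb_logs',
                           offline=True, project='my-project', tags=['demo'])

# optional: have wandb track gradients of the model's parameters
wandb_logger.watch(model, log='gradients', log_freq=100)

trainer = Trainer(logger=wandb_logger)
trainer.fit(model)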

tests/requirements.txt (+1)

@@ -8,6 +8,7 @@ check-manifest
 # test_tube # already installed in main req.
 mlflow
 comet_ml
+wandb
 neptune-client
 twine==1.13.0
 pillow<7.0.0

tests/test_logging.py (+16 -1)

@@ -192,6 +192,14 @@ def test_comet_pickle(tmpdir, monkeypatch):
     trainer2 = pickle.loads(pkl_bytes)
     trainer2.logger.log_metrics({"acc": 1.0})
 
+def test_wandb_logger(tmpdir):
+    """Verify that basic functionality of wandb logger works."""
+    tutils.reset_seed()
+
+    from pytorch_lightning.logging import WandbLogger
+
+    wandb_dir = os.path.join(tmpdir, "wandb")
+    logger = WandbLogger(save_dir=wandb_dir, anonymous=True)
 
 def test_neptune_logger(tmpdir):
     """Verify that basic functionality of neptune logger works."""

@@ -201,7 +209,6 @@ def test_neptune_logger(tmpdir):
 
     hparams = tutils.get_hparams()
     model = LightningTestModel(hparams)
-
     logger = NeptuneLogger(offline_mode=True)
 
     trainer_options = dict(

@@ -216,6 +223,13 @@ def test_neptune_logger(tmpdir):
     print('result finished')
     assert result == 1, "Training failed"
 
+def test_wandb_pickle(tmpdir):
+    """Verify that pickling trainer with wandb logger works."""
+    tutils.reset_seed()
+
+    from pytorch_lightning.logging import WandbLogger
+    wandb_dir = str(tmpdir)
+    logger = WandbLogger(save_dir=wandb_dir, anonymous=True)
 
 def test_neptune_pickle(tmpdir):
     """Verify that pickling trainer with neptune logger works."""

@@ -227,6 +241,7 @@ def test_neptune_pickle(tmpdir):
     # model = LightningTestModel(hparams)
 
     logger = NeptuneLogger(offline_mode=True)
+
     trainer_options = dict(
         default_save_path=tmpdir,
         max_epochs=1,
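
Note that test_wandb_pickle above only constructs the logger. A hedged sketch of how it could be extended to actually round-trip a Trainer through pickle, following the test_comet_pickle pattern in the same file; offline=True is an assumption here, added to keep the lazy re-init in __getstate__ off the network:

import pickle

from pytorch_lightning import Trainer
from pytorch_lightning.logging import WandbLogger

def test_wandb_pickle_roundtrip(tmpdir):
    """Sketch (not in this commit): pickle a trainer holding a WandbLogger."""
    logger = WandbLogger(save_dir=str(tmpdir), anonymous=True, offline=True)
    trainer = Trainer(default_save_path=tmpdir, max_epochs=1, logger=logger)

    # WandbLogger.__getstate__ drops the live run and keeps only its id,
    # so the unpickled logger can lazily re-create the same experiment.
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0})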

tests/test_restore_models.py (+7 -4)

@@ -106,7 +106,7 @@ def test_load_model_from_checkpoint(tmpdir):
 
     trainer_options = dict(
         show_progress_bar=False,
-        max_epochs=5,
+        max_epochs=2,
         train_percent_check=0.4,
         val_percent_check=0.2,
         checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),

@@ -120,9 +120,12 @@ def test_load_model_from_checkpoint(tmpdir):
 
     # correct result and ok accuracy
     assert result == 1, 'training failed to complete'
-    pretrained_model = LightningTestModel.load_from_checkpoint(
-        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_4.ckpt")
-    )
+
+    # load last checkpoint
+    last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
+    if not os.path.isfile(last_checkpoint):
+        last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
+    pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
 
     # test that hparams loaded correctly
     for k, v in vars(hparams).items():
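
The fallback above hard-codes _ckpt_epoch_1 / _ckpt_epoch_0, which tracks max_epochs=2 but would break if that number changes again. A more general sketch (hypothetical helper, not part of this commit) picks the newest checkpoint by parsing the epoch out of the filename:

import glob
import os

def find_last_checkpoint(ckpt_dir):
    """Hypothetical helper: return the _ckpt_epoch_*.ckpt with the highest epoch."""
    paths = glob.glob(os.path.join(ckpt_dir, "_ckpt_epoch_*.ckpt"))
    if not paths:
        raise FileNotFoundError("no checkpoints found in %s" % ckpt_dir)

    def epoch_of(path):
        # filenames look like _ckpt_epoch_3.ckpt
        stem = os.path.basename(path)
        return int(stem.replace("_ckpt_epoch_", "").replace(".ckpt", ""))

    return max(paths, key=epoch_of)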
