-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: wandb logger #627
Changes from 8 commits
477a9e4
bda27f0
6ca6ba6
e705a79
cd7c71c
04dd339
9305c50
98d48e4
0d833c3
16349e4
e52b7ac
f869d15
d780f5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
""" | ||
Log using `W&B <https://www.wandb.com>`_ | ||
|
||
.. code-block:: python | ||
|
||
from pytorch_lightning.logging import WandbLogger | ||
wandb_logger = WandbLogger( | ||
name="my_new_run", # Optional, display name | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If all are optimal, there is no need to write it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here the intent was just to show sample use if anybody wanted to use those args as it does not appear in the doc and they would have to look at the docstring of |
||
save_dir="my_folder", # Optional, path to save data | ||
offline=False, # Optional, run offline (can sync later) | ||
version="run_12345", # Optional, used to restart previous run | ||
id="run_12345", # Optional, same as version | ||
anonymous=False, # Optional, enable or disable anonymous logging | ||
project="bert", # Optional, project to which run belongs to | ||
tags=["tag1", "tag2"] # Optional, tags associated with run | ||
) | ||
trainer = Trainer(logger=wandb_logger) | ||
|
||
|
||
Use the logger anywhere in your LightningModule as follows: | ||
|
||
.. code-block:: python | ||
|
||
def train_step(...): | ||
# example | ||
self.logger.experiment.whatever_wandb_supports(...) | ||
|
||
def any_lightning_module_function_or_hook(...): | ||
self.logger.experiment.whatever_wandb_supports(...) | ||
|
||
""" | ||
|
||
import os | ||
|
||
try: | ||
import wandb | ||
except ImportError: | ||
raise ImportError('Missing wandb package.') | ||
|
||
from .base import LightningLoggerBase, rank_zero_only | ||
|
||
|
||
class WandbLogger(LightningLoggerBase):
    """
    Logger for `Weights & Biases <https://www.wandb.com>`_.

    The underlying wandb run is created lazily on first access of
    :attr:`experiment`, so constructing the logger has no side effects.

    Args:
        name (str): display name for the run.
        save_dir (str): path where data is saved.
        offline (bool): run offline (data can be streamed later to wandb servers).
        id (str): sets the version, mainly used to resume a previous run.
        anonymous (bool): enables or explicitly disables anonymous logging.
        version (str): same as ``id``; takes precedence when both are given.
        project (str): the name of the project to which this run will belong.
        tags (list of str): tags associated with this run.
    """

    def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=False,
                 version=None, project=None, tags=None):
        super().__init__()
        self._name = name
        self._save_dir = save_dir
        # wandb expects the literal string "allow" to turn anonymous logging on
        self._anonymous = "allow" if anonymous else None
        # `version` wins over `id` when both are supplied
        self._id = version or id
        self._tags = tags
        self._project = project
        self._experiment = None
        self._offline = offline

    def __getstate__(self):
        """Support pickling: drop the live wandb run but keep its id so the
        unpickled logger resumes the same experiment."""
        state = self.__dict__.copy()
        # the live run object cannot be pickled
        state['_experiment'] = None
        # args needed to reload the correct experiment on unpickle
        state['_id'] = self.experiment.id
        return state

    @property
    def experiment(self):
        """The actual wandb run object, created on first access."""
        if self._experiment is None:
            if self._offline:
                # dryrun mode keeps data local; it can be synced later
                os.environ["WANDB_MODE"] = "dryrun"
            self._experiment = wandb.init(
                name=self._name, dir=self._save_dir, project=self._project,
                anonymous=self._anonymous, id=self._id, resume="allow", tags=self._tags)
        return self._experiment

    def watch(self, model, log="gradients", log_freq=100):
        """Hook wandb's gradient/parameter logging into ``model``."""
        wandb.watch(model, log, log_freq)

    @rank_zero_only
    def log_hyperparams(self, params):
        """Record hyperparameters in the run's config."""
        self.experiment.config.update(params)

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Log a dict of metric values for the given global step.

        The input dict is copied before ``global_step`` is added, so the
        caller's dict is never mutated (the original wrote into it in place).
        """
        metrics = dict(metrics, global_step=step)
        self.experiment.history.add(metrics)

    def save(self):
        # wandb streams data continuously; there is nothing to flush explicitly
        pass

    @rank_zero_only
    def finalize(self, status='success'):
        """Close the run, reporting a non-zero exit code on failure."""
        try:
            exit_code = 0 if status == 'success' else 1
            wandb.join(exit_code)
        except TypeError:
            # older wandb versions accept no exit-code argument
            wandb.join()

    @property
    def name(self):
        return self.experiment.project_name()

    @property
    def version(self):
        return self.experiment.id
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ check-manifest | |
# test_tube # already installed in main req. | ||
mlflow | ||
comet_ml | ||
wandb | ||
twine==1.13.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -193,6 +193,63 @@ def test_comet_pickle(tmpdir, monkeypatch): | |
trainer2.logger.log_metrics({"acc": 1.0}) | ||
|
||
|
||
def test_wandb_logger(tmpdir):
    """Verify that basic functionality of wandb logger works."""
    tutils.reset_seed()

    # Import directly (no try/except): if the wandb package is missing,
    # this test must FAIL, not be silently skipped.
    from pytorch_lightning.logging import WandbLogger

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # tmpdir is already unique and clean — no sub-folder needed
    logger = WandbLogger(save_dir=str(tmpdir), anonymous=True)

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.01,
        logger=logger,
    )
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, "Training failed"
|
||
|
||
def test_wandb_pickle(tmpdir):
    """Verify that pickling trainer with wandb logger works."""
    tutils.reset_seed()

    # Import directly (no try/except): if the wandb package is missing,
    # this test must FAIL, not be silently skipped.
    from pytorch_lightning.logging import WandbLogger

    # tmpdir is already unique and clean — no sub-folder needed
    logger = WandbLogger(save_dir=str(tmpdir), anonymous=True)

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        logger=logger,
    )

    trainer = Trainer(**trainer_options)
    # round-trip the whole trainer and make sure the restored logger still logs
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0})
|
||
|
||
def test_tensorboard_logger(tmpdir): | ||
"""Verify that basic functionality of Tensorboard logger works.""" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider write it as doctest