Skip to content

Commit 6b41b5c

Browse files
authored
feat(wandb): save models on wandb (#1339)
* feat(wandb): save models on wandb * docs(changelog): allow to upload models on W&B
1 parent 04935ea commit 6b41b5c

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
2222
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
2323
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
24+
- Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339))
2425

2526
### Changed
2627

pytorch_lightning/loggers/wandb.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class WandbLogger(LightningLoggerBase):
3333
anonymous (bool): enables or explicitly disables anonymous logging.
3434
project (str): the name of the project to which this run will belong.
3535
tags (list of str): tags associated with this run.
36+
log_model (bool): save checkpoints in wandb dir to upload on W&B servers.
3637
3738
Example
3839
--------
@@ -48,7 +49,8 @@ class WandbLogger(LightningLoggerBase):
4849
def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None,
4950
offline: bool = False, id: Optional[str] = None, anonymous: bool = False,
5051
version: Optional[str] = None, project: Optional[str] = None,
51-
tags: Optional[List[str]] = None, experiment=None, entity=None):
52+
tags: Optional[List[str]] = None, log_model: bool = False,
53+
experiment=None, entity=None):
5254
super().__init__()
5355
self._name = name
5456
self._save_dir = save_dir
@@ -59,6 +61,7 @@ def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None,
5961
self._experiment = experiment
6062
self._offline = offline
6163
self._entity = entity
64+
self._log_model = log_model
6265

6366
def __getstate__(self):
6467
state = self.__dict__.copy()
@@ -85,6 +88,9 @@ def experiment(self) -> Run:
8588
self._experiment = wandb.init(
8689
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
8790
id=self._id, resume='allow', tags=self._tags, entity=self._entity)
91+
# save checkpoints in wandb dir to upload on W&B servers
92+
if self._log_model:
93+
self.save_dir = self._experiment.dir
8894
return self._experiment
8995

9096
def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):

0 commit comments

Comments
 (0)