
Commit 36b0fc0

Merge branch 'master' into add_seed_for_reproducibility
2 parents: a77c702 + 1df0d2d

40 files changed (+162 −64 lines)

.github/workflows/docker-builds.yml

+15 −10

@@ -1,7 +1,7 @@
-name: Publish Docker
+name: Publish Docker Releases
 on:
   push:
-    branches:
+    branches:
     - master
   release:
     types:
@@ -15,22 +15,27 @@ jobs:
       python_version: [3.6, 3.7, 3.8]
       pytorch_version: [1.1, 1.2, 1.3, 1.4, 1.5]
     steps:
-      - name: Extract Current Tag
-        if: contains(github.ref, 'refs/tags/')
-        id: get_version
-        run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
+      - name: Extract branch name
+        if: contains(github.ref, 'refs/tags/') != true
+        shell: bash
+        run: echo "##[set-output name=branch;]$(echo ${GITHUB_REF#refs/heads/})"
+        id: extract_tag
+      - name: Extract Tag name
+        if: contains(github.ref, 'refs/tags')
+        shell: bash
+        run: echo "##[set-output name=tag;]$(echo ${GITHUB_REF#refs/tags/})"
      - uses: actions/checkout@v2
      - name: Publish Releases to Docker
        # only on releases
        uses: elgohr/[email protected]
-        if: contains(github.ref, 'refs/tags/') && !contains(${{ steps.get_version.outputs.VERSION }}, 'rc') %% !contains(${{ steps.get_version.outputs.VERSION }}, 'dev')
+        if: contains(github.ref, 'refs/tags/') && !contains(${{ steps.extract_tag.outputs.tag }}, 'rc') && !contains(${{ steps.extract_tag.outputs.tag }}, 'dev')
        with:
          name: pytorchlightning/pytorch_lightning
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
          dockerfile: docker/Dockerfile
-          buildargs: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.VERSION }}
-          tags: "${{ steps.get_version.outputs.VERSION }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},stable-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}"
+          buildargs: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.extract_tag.outputs.tag }}
+          tags: "${{ steps.extract_tag.outputs.tag }}-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }},stable-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}"
      - name: Publish Master
        # publish master
        uses: elgohr/[email protected]
@@ -40,5 +45,5 @@ jobs:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
          dockerfile: docker/Dockerfile
-          buildargs: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.get_version.outputs.VERSION }}
+          buildargs: PYTHON_VERSION=${{ matrix.python_version }},PYTORCH_VERSION=${{ matrix.pytorch_version }},LIGHTNING_VERSION=${{ steps.extract_branch.outputs.branch }}
       tags: "nightly-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}"

CHANGELOG.md

+20 −2

@@ -20,21 +20,34 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added option to provide seed to random generators to ensure reproducibility ([#1572](https://github.com/PyTorchLightning/pytorch-lightning/pull/1572))

+- Enable `NeptuneLogger` to work with `distributed_backend=ddp` ([#1753](https://github.com/PyTorchLightning/pytorch-lightning/pull/1753))
+
 ### Changed

 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))

 - Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))

+- Don't convert `namedtuple` to `tuple` when transferring the batch to target device ([#1589](https://github.com/PyTorchLightning/pytorch-lightning/pull/1589))
+
+- Allow passing hparams as keyword argument to LightningModule when loading from checkpoint ([#1639](https://github.com/PyTorchLightning/pytorch-lightning/pull/1639))
+
 ### Deprecated

 ### Removed

 ### Fixed

-- Fixed ModelCheckpoint not None checking filepath ([1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654))
+- Fixed broken link in PR template ([#1675](https://github.com/PyTorchLightning/pytorch-lightning/pull/1675))
+
+- Fixed ModelCheckpoint not None checking filepath ([#1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654))

-- Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666))
+- Trainer now calls `on_load_checkpoint()` when resuming from a checkpoint ([#1666](https://github.com/PyTorchLightning/pytorch-lightning/pull/1666))
+
+- Fixed sampler logic for ddp with iterable dataset ([#1734](https://github.com/PyTorchLightning/pytorch-lightning/pull/1734))
+
+- Fixed `_reset_eval_dataloader()` for IterableDataset ([#1560](https://github.com/PyTorchLightning/pytorch-lightning/pull/1560))

 - Fixed Horovod distributed backend to set the `root_gpu` property ([#1669](https://github.com/PyTorchLightning/pytorch-lightning/pull/1669))

@@ -46,6 +59,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))

+- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))
+
+- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))
+
 ## [0.7.5] - 2020-04-27

 ### Changed
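
The reproducibility entry ([#1572]) is the feature this branch merges. A minimal sketch of what seeding every generator typically involves; `seed_all` is an illustrative name, not necessarily the exact API the PR adds:

    import os
    import random

    import numpy as np
    import torch


    def seed_all(seed: int = 42) -> None:
        """Seed the usual sources of randomness so training runs are repeatable."""
        random.seed(seed)                    # Python's built-in RNG
        np.random.seed(seed)                 # NumPy RNG
        torch.manual_seed(seed)              # PyTorch CPU RNG
        torch.cuda.manual_seed_all(seed)     # PyTorch CUDA RNGs (no-op without GPUs)
        os.environ["PYTHONHASHSEED"] = str(seed)


    seed_all(42)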

README.md

+1 −1

@@ -15,7 +15,7 @@
 [![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=0.7.5)](https://pytorch-lightning.readthedocs.io/en/stable/)
 [![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-lightning/shared_invite/enQtODU5ODIyNTUzODQwLTFkMDg5Mzc1MDBmNjEzMDgxOTVmYTdhYjA1MDdmODUyOTg2OGQ1ZWZkYTQzODhhNzdhZDA3YmNhMDhlMDY4YzQ)
 [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/PytorchLightning/pytorch-lightning/blob/master/LICENSE)
-[![Next Release](https://img.shields.io/badge/Next%20Release-May%2006-<COLOR>.svg)](https://shields.io/)
+[![Next Release](https://img.shields.io/badge/Next%20Release-May%2020-<COLOR>.svg)](https://shields.io/)

 <!--
 removed until codecov badge isn't empy. likely a config error showing nothing on master.

docs/source/multi_gpu.rst

+3 −1

@@ -46,7 +46,9 @@ This will make your code scale to any arbitrary number of GPUs or TPUs with Ligh
     # with lightning
     def forward(self, x):
         z = torch.Tensor(2, 3)
-        z = z.type_as(x)
+        z = z.type_as(x, device=self.device)
+
+Every LightningModule knows what device it is on. You can access that reference via `self.device`.

 Remove samplers
 ^^^^^^^^^^^^^^^
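
The added sentence documents that every LightningModule carries a `self.device` reference. The underlying pattern the doc snippet illustrates, creating new tensors that follow the input instead of hard-coding `.cuda()` or `.cpu()`, also works with plain `type_as`; the module below is a minimal standalone example, not code from the repo:

    import torch
    from torch import nn


    class DeviceAwareModule(nn.Module):
        """Minimal example: new tensors follow the device/dtype of the incoming batch."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(3, 1)

        def forward(self, x):
            # type_as moves z to x's device and dtype, so no explicit .cuda()/.cpu() calls
            z = torch.zeros(x.size(0), 3).type_as(x)
            return self.linear(x + z)


    module = DeviceAwareModule()
    print(module(torch.randn(4, 3)).shape)  # torch.Size([4, 1])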

pl_examples/domain_templates/generative_adversarial_net.py

+1 −1

@@ -173,7 +173,7 @@ def on_epoch_end(self):
         # log sampled images
         sample_imgs = self(z)
         grid = torchvision.utils.make_grid(sample_imgs)
-        self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch)
+        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)


 def main(hparams):

pytorch_lightning/__init__.py

+3 −2

@@ -1,6 +1,6 @@
 """Root package info."""

-__version__ = '0.7.5'
+__version__ = '0.7.6rc1'
 __author__ = 'William Falcon et al.'
 __author_email__ = '[email protected]'
 __license__ = 'Apache-2.0'
@@ -34,7 +34,8 @@
 import logging as python_logging

 _logger = python_logging.getLogger("lightning")
-python_logging.basicConfig(level=python_logging.INFO)
+_logger.addHandler(python_logging.StreamHandler())
+_logger.setLevel(python_logging.INFO)

 try:
     # This variable is injected in the __builtins__ by the build
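
The logging change matters because `logging.basicConfig` configures the root logger for the whole process, while attaching a handler to the `lightning` logger only affects this package. A small standalone sketch of the difference:

    import logging

    # Configure only the package logger; the root logger (and other libraries) stay untouched.
    logger = logging.getLogger("lightning")
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

    logger.info("printed: the 'lightning' logger has its own handler and level")
    logging.getLogger("other.library").info("not printed: the root logger was never configured")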

pytorch_lightning/callbacks/lr_logger.py

+2 −2

@@ -80,7 +80,7 @@ def _extract_lr(self, trainer, interval):
             param_groups = scheduler['scheduler'].optimizer.param_groups
             if len(param_groups) != 1:
                 for i, pg in enumerate(param_groups):
-                    lr, key = pg['lr'], f'{name}/{i + 1}'
+                    lr, key = pg['lr'], f'{name}/pg{i + 1}'
                     self.lrs[key].append(lr)
                     latest_stat[key] = lr
             else:
@@ -109,7 +109,7 @@ def _find_names(self, lr_schedulers):
             param_groups = sch.optimizer.param_groups
             if len(param_groups) != 1:
                 for i, pg in enumerate(param_groups):
-                    temp = name + '/pg' + str(i + 1)
+                    temp = f'{name}/pg{i + 1}'
                     names.append(temp)
             else:
                 names.append(name)
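
With more than one parameter group, both `_extract_lr` and `_find_names` now build keys as `{name}/pg{i + 1}`, so the logged names agree. A standalone sketch of the resulting keys; the optimizer setup and scheduler name are illustrative:

    import torch
    from torch import nn, optim

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    optimizer = optim.SGD([
        {"params": model[0].parameters(), "lr": 0.1},
        {"params": model[1].parameters(), "lr": 0.01},
    ])
    name = "lr-SGD"  # illustrative scheduler name

    # One entry per param group, suffixed /pg1, /pg2, ... as the callback now keys them
    latest_stat = {f"{name}/pg{i + 1}": pg["lr"] for i, pg in enumerate(optimizer.param_groups)}
    print(latest_stat)  # {'lr-SGD/pg1': 0.1, 'lr-SGD/pg2': 0.01}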

pytorch_lightning/callbacks/progress.py

+1 −1

@@ -93,7 +93,7 @@ def total_val_batches(self) -> int:
         """
         trainer = self.trainer
         total_val_batches = 0
-        if trainer.fast_dev_run:
+        if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
         elif not self.trainer.disable_validation:
             is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0

pytorch_lightning/core/lightning.py

+3 −0

@@ -72,6 +72,9 @@ def __init__(self, *args, **kwargs):

         self.hparams = None

+        #: device reference
+        self.device = None
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.

pytorch_lightning/core/model_saving.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.saving import *  # noqa: F403

 rank_zero_warn("`model_saving` module has been renamed to `saving` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.saving import *  # noqa: F403
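
This file, like the other renamed-module shims touched below, follows one pattern: re-export the new location and warn when the old path is imported; the change only moves the star-import up next to the other imports. A generic, runnable sketch of the pattern, using a standard-library module as a stand-in for the relocated one:

    # legacy_saving.py -- illustrative shim; the module names are made up for the example.
    """Deprecated alias kept so `import legacy_saving` keeps working after a rename."""

    import warnings

    # Re-export the relocated implementation so existing call sites see the same names.
    from json import *  # noqa: F401,F403  (stand-in for the new module location)

    warnings.warn(
        "`legacy_saving` has been renamed; import the new module instead."
        " The deprecated name will be removed in a future release.",
        DeprecationWarning,
    )

In pytorch-lightning the warning goes through `rank_zero_warn`, so in distributed runs only process 0 emits it.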

pytorch_lightning/core/root_module.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.lightning import *  # noqa: F403

 rank_zero_warn("`root_module` module has been renamed to `lightning` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.lightning import *  # noqa: F403

pytorch_lightning/logging/__init__.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.loggers import *  # noqa: F403

 rank_zero_warn("`logging` package has been renamed to `loggers` since v0.7.0"
                " The deprecated package name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers import *  # noqa: F403

pytorch_lightning/logging/comet.py

+1 −2

@@ -3,8 +3,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.loggers.comet import CometLogger  # noqa: F403

 rank_zero_warn("`logging.comet` module has been renamed to `loggers.comet` since v0.7.0."
                " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.comet import CometLogger  # noqa: F403

pytorch_lightning/logging/mlflow.py

+1 −2

@@ -3,8 +3,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.loggers.mlflow import MLFlowLogger  # noqa: F403

 rank_zero_warn("`logging.mlflow` module has been renamed to `loggers.mlflow` since v0.7.0."
                " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.mlflow import MLFlowLogger  # noqa: F403

pytorch_lightning/logging/neptune.py

+1 −2

@@ -3,8 +3,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.loggers.neptune import NeptuneLogger  # noqa: F403

 rank_zero_warn("`logging.neptune` module has been renamed to `loggers.neptune` since v0.7.0."
                " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.neptune import NeptuneLogger  # noqa: F403

pytorch_lightning/logging/test_tube.py

+1 −2

@@ -3,8 +3,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.loggers.test_tube import TestTubeLogger  # noqa: F403

 rank_zero_warn("`logging.test_tube` module has been renamed to `loggers.test_tube` since v0.7.0."
                " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.test_tube import TestTubeLogger  # noqa: F403

pytorch_lightning/logging/wandb.py

+1 −2

@@ -3,8 +3,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.loggers.wandb import WandbLogger  # noqa: F403

 rank_zero_warn("`logging.wandb` module has been renamed to `loggers.wandb` since v0.7.0."
                " The deprecated module name will be removed in v0.9.0.", DeprecationWarning)
-
-from pytorch_lightning.loggers.wandb import WandbLogger  # noqa: F403

pytorch_lightning/pt_overrides/override_data_parallel.py

+2 −3

@@ -4,9 +4,8 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.overrides.data_parallel import (  # noqa: F402
+    get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel)

 rank_zero_warn("`override_data_parallel` module has been renamed to `data_parallel` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.overrides.data_parallel import (  # noqa: F402
-    get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel)

pytorch_lightning/root_module/decorators.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.decorators import *  # noqa: F403

 rank_zero_warn("`root_module.decorators` module has been renamed to `core.decorators` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.decorators import *  # noqa: F403

pytorch_lightning/root_module/grads.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.grads import *  # noqa: F403

 rank_zero_warn("`root_module.grads` module has been renamed to `core.grads` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.grads import *  # noqa: F403

pytorch_lightning/root_module/hooks.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.hooks import *  # noqa: F403

 rank_zero_warn("`root_module.hooks` module has been renamed to `core.hooks` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.hooks import *  # noqa: F403

pytorch_lightning/root_module/memory.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.memory import *  # noqa: F403

 rank_zero_warn("`root_module.memory` module has been renamed to `core.memory` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.memory import *  # noqa: F403

pytorch_lightning/root_module/model_saving.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.saving import *  # noqa: F403

 rank_zero_warn("`root_module.model_saving` module has been renamed to `core.saving` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.saving import *  # noqa: F403

pytorch_lightning/root_module/root_module.py

+1 −2

@@ -4,8 +4,7 @@
 """

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.core.lightning import *  # noqa: F403

 rank_zero_warn("`root_module.root_module` module has been renamed to `core.lightning` since v0.6.0."
                " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
-
-from pytorch_lightning.core.lightning import *  # noqa: F403

pytorch_lightning/trainer/distrib_data_parallel.py

+1 −0

@@ -344,6 +344,7 @@ def ddp_train(self, process_idx, model):
         # copy model to each gpu
         if self.on_gpu:
             self.root_gpu = process_idx
+            self.device = torch.device('cuda', self.root_gpu)
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)

pytorch_lightning/trainer/distrib_parts.py

+5 −0

@@ -432,6 +432,7 @@ def copy_trainer_model_properties(self, model):
             m.use_tpu = self.use_tpu
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank
+            m.device = self.device

     def transfer_batch_to_tpu(self, batch):
         return self.__transfer_data_to_device(batch, device='tpu')
@@ -483,6 +484,7 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):

     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
+        self.device = torch.device('cuda', self.root_gpu)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -499,6 +501,7 @@ def single_gpu_train(self, model):
     def tpu_train(self, tpu_core_idx, model):
         # put model on tpu
         model.to(xm.xla_device())
+        self.device = xm.xla_device()

         # get the appropriate tpu ranks
         self.tpu_local_core_rank = xm.get_local_ordinal()
@@ -536,6 +539,7 @@ def dp_train(self, model):
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         model.cuda(self.root_gpu)
+        self.device = torch.device('cuda', self.root_gpu)

         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
@@ -575,6 +579,7 @@ def horovod_train(self, model):
         assert self.root_gpu == hvd.local_rank()
         torch.cuda.set_device(self.root_gpu)
         model.cuda(self.root_gpu)
+        self.device = torch.device('cuda', self.root_gpu)

         # avoid duplicating progress bar
         if hvd.rank() != 0 and self.progress_bar_callback is not None:

pytorch_lightning/trainer/evaluation_loop.py

+4 −0

@@ -360,6 +360,10 @@ def run_evaluation(self, test_mode: bool = False):
             dataloaders = self.val_dataloaders
             max_batches = self.num_val_batches

+        # enable fast_dev_run without val loop
+        if dataloaders is None:
+            return
+
         # cap max batches to 1 when using fast_dev_run
         if self.fast_dev_run:
             max_batches = 1
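
Together with the progress-bar guard above, this early return lets `fast_dev_run` work on models that define no validation dataloader: when there is nothing to evaluate, the loop simply exits. A tiny standalone sketch of the guard; the function name and return values are illustrative:

    def run_evaluation_sketch(val_dataloaders, fast_dev_run: bool = True):
        """Illustrative stand-in for the guarded evaluation entry point."""
        dataloaders = val_dataloaders
        # enable fast_dev_run without val loop: nothing to evaluate, so bail out early
        if dataloaders is None:
            return None
        # cap max batches to 1 when using fast_dev_run
        return 1 if fast_dev_run else sum(len(dl) for dl in dataloaders)


    print(run_evaluation_sketch(None))         # None -> evaluation skipped entirely
    print(run_evaluation_sketch([range(10)]))  # 1 -> capped by fast_dev_run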
