Fix user warning produced by apex + scheduler combination #1873

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

## [0.7.6] - 2020-05-16

### Added
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -372,6 +372,7 @@ def ddp_train(self, process_idx, model):
        if self.use_amp and not self.use_native_amp:
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # DDP2 uses all GPUs on the machine
        if self.distributed_backend == 'ddp':
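For context, here is a minimal sketch (illustrative only, not part of this diff) of the warning that the reinit_scheduler_properties call above silences, reproduced without apex by re-binding optimizer.step the way amp.initialize effectively does. Under the PyTorch releases current at the time of this PR (1.4/1.5), _LRScheduler wraps optimizer.step with a counting wrapper at construction time; replacing the step method afterwards discards that wrapper, so the scheduler's first step() emits a UserWarning about optimizer.step() having been overridden.

import types
import warnings

import torch
from torch import optim

model = torch.nn.Linear(2, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)  # wraps optimizer.step

# stand-in for apex: re-bind optimizer.step after the scheduler already exists
original_step = optimizer.step

def patched_step(self, *args, **kwargs):
    return original_step(*args, **kwargs)

optimizer.step = types.MethodType(patched_step, optimizer)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    optimizer.step()
    scheduler.step()  # first scheduler step after the patch raises the UserWarning

print([str(w.message) for w in caught])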
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -497,6 +497,7 @@ def single_gpu_train(self, model):
            # An example
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        self.run_pretrain_routine(model)

@@ -559,6 +560,7 @@ def dp_train(self, model):
                    f' We recommend you switch to ddp if you want to use amp')
            else:
                model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)

        # create list of device ids
        device_ids = self.data_parallel_device_ids
@@ -599,6 +601,7 @@ def horovod_train(self, model):
            # An example
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
@@ -108,6 +108,19 @@ def configure_schedulers(self, schedulers: list):
                                 'is a invalid input.')
        return lr_schedulers

    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
        # Reinitialize the properties that schedulers attach to optimizer.step,
        # which are lost once apex re-binds the optimizer's step method
        for scheduler in schedulers:
            scheduler = scheduler['scheduler']
            for optimizer in optimizers:
                # check that we don't mix the user's optimizers and schedulers
                if scheduler.optimizer == optimizer:
                    # find the mro entry belonging to the base lr scheduler class
                    for i, mro in enumerate(scheduler.__class__.__mro__):
                        if mro == optim.lr_scheduler._LRScheduler:
                            idx = i
                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
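Below is a standalone sketch (again illustrative, torch-only, assuming the 1.4/1.5-era scheduler layout where schedulers derive from _LRScheduler) of what the MRO walk in reinit_scheduler_properties accomplishes: re-running the base class __init__ against the patched optimizer re-installs the counting wrapper on the new optimizer.step, so the subsequent scheduler.step() no longer warns.

import types

import torch
from torch import optim

model = torch.nn.Linear(2, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

# stand-in for apex: re-bind optimizer.step after the scheduler already exists
original_step = optimizer.step

def patched_step(self, *args, **kwargs):
    return original_step(*args, **kwargs)

optimizer.step = types.MethodType(patched_step, optimizer)

# same trick as reinit_scheduler_properties: find the base scheduler class in
# the MRO and re-run its __init__ so it re-wraps the patched optimizer.step
for klass in scheduler.__class__.__mro__:
    if klass is optim.lr_scheduler._LRScheduler:
        klass.__init__(scheduler, optimizer)
        break

optimizer.step()
scheduler.step()  # no "optimizer.step() has been overridden" warning this time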


class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`