
Commit 1369012

Authored by tgaddair (co-authored by williamFalcon and Borda)
Horovod: adjust base LR used by schedulers to scale with the number of workers (#2626)
* Horovod: Adjust base LR used by schedulers to match that of the optimizer after scaling by number of workers
* Added unit test
* Removed debug statements
* Updated changelog
* Apply suggestions from code review

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent bda7cf1 commit 1369012
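In essence, the change makes the Horovod backend scale a scheduler's stored base learning rates by the same factor it already applies to the optimizer's param groups. A minimal standalone sketch of that idea (not the library's implementation; num_workers stands in for hvd.size()):

from torch.optim.lr_scheduler import _LRScheduler


def scale_lr_for_horovod(optimizers, lr_schedulers, num_workers):
    """Scale optimizer LRs and matching scheduler base LRs by the worker count."""
    # Horovod averages gradients across workers, so the effective batch size grows
    # with num_workers; the LR is scaled up to compensate.
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= num_workers

    # Without this second loop, schedulers that recompute the LR from base_lrs
    # would undo the scaling on their next step.
    for scheduler in lr_schedulers:
        if isinstance(scheduler, _LRScheduler):
            scheduler.base_lrs = [lr * num_workers for lr in scheduler.base_lrs]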

File tree: 3 files changed, +45 −2 lines (CHANGELOG.md, pytorch_lightning/trainer/distrib_parts.py, tests/models/test_horovod.py)

CHANGELOG.md (+2)

@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed Horovod backend to scale LR schedulers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))
+
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
 
 ## [0.8.5] - 2020-07-09

pytorch_lightning/trainer/distrib_parts.py (+7 −1)

@@ -10,6 +10,7 @@
 import time
 import random
 import torch
+from torch.optim.lr_scheduler import _LRScheduler
 from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
 
 from pytorch_lightning.core.lightning import LightningModule
@@ -298,8 +299,13 @@ def horovod_train(self, model):
             for param_group in optimizer.param_groups:
                 param_group['lr'] *= hvd.size()
 
+        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
+        for scheduler in self.lr_schedulers:
+            scheduler = scheduler['scheduler']
+            if isinstance(scheduler, _LRScheduler):
+                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
+
         if self.use_amp:
-            # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
             self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
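To see why scaling only the optimizer is not enough, note that several torch schedulers recompute the learning rate from their stored base_lrs. A small standalone demonstration (LambdaLR is used purely for illustration because it derives each LR from base_lrs; the factor 8 stands in for hvd.size()):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=0.1)
sched = LambdaLR(opt, lr_lambda=lambda epoch: 0.1 ** epoch)  # base_lrs captured here: [0.1]

# Scale only the optimizer, as the code did before this commit (pretend hvd.size() == 8):
for group in opt.param_groups:
    group['lr'] *= 8                  # lr is now 0.8

sched.step()                          # epoch 1: lr = base_lr * 0.1 ** 1
print(opt.param_groups[0]['lr'])      # ~0.01 -- computed from the unscaled base_lr of 0.1

# With the fix, base_lrs is also multiplied by 8, so the same step yields ~0.08.

Exact behaviour varies by scheduler type and PyTorch version; the point is only that base_lrs is part of the scheduler's state and must be kept consistent with the scaled optimizer.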

tests/models/test_horovod.py (+36 −1)

@@ -5,6 +5,9 @@
 import subprocess
 import sys
 
+from unittest.mock import patch
+
+import numpy as np
 import pytest
 import torch
 
@@ -113,7 +116,6 @@ def test_horovod_multi_gpu(tmpdir):
 @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_horovod_transfer_batch_to_gpu(tmpdir):
-
     class TestTrainingStepModel(EvalModelTemplate):
         def training_step(self, batch, *args, **kwargs):
             x, y = batch
@@ -175,3 +177,36 @@ def get_optimizer_params(optimizer):
     assert get_model_params(model.generator) != get_model_params(model.discriminator)
     assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
     assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
+def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    num_workers = 8
+    init_lr = hparams.get('learning_rate') * num_workers
+
+    with patch('pytorch_lightning.trainer.distrib_parts.hvd.size') as mock_hvd_size:
+        mock_hvd_size.return_value = 8
+
+        # fit model
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=1,
+            limit_val_batches=0.5,
+            limit_train_batches=0.2,
+            distributed_backend='horovod'
+        )
+        results = trainer.fit(model)
+        assert results == 1
+
+    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
+    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
+
+    # Called once after the end of the epoch with gamma=0.1
+    assert pytest.approx(init_lr * 0.1) == adjusted_lr1
+
+    # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
+    assert pytest.approx(init_lr * 0.1) == adjusted_lr2
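A note on the test design: patching hvd.size() at the point where distrib_parts looks it up lets the scaling logic run in a single process, without launching horovodrun. The assertion arithmetic, spelled out with an assumed base learning rate (the test itself reads it from EvalModelTemplate's defaults):

base_lr = 0.001                  # assumed value of hparams['learning_rate'], for illustration only
num_workers = 8                  # value returned by the mocked hvd.size()
gamma = 0.1                      # decay factor referenced in the test comments

init_lr = base_lr * num_workers  # LR after Horovod scaling: 0.008
expected = init_lr * gamma       # 0.0008 -- what both assertions compare the adjusted LRs against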
