From 6998f4416e024afebdfb4139f6884b9514965d09 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 15 Jul 2020 18:18:23 -0700
Subject: [PATCH 1/5] Horovod: Adjust base LR used by schedulers to match that of the optimizer after scaling by number of workers

---
 pytorch_lightning/trainer/distrib_parts.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 20edc0d60541a..b8ff9a03e75a5 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -10,6 +10,7 @@
 import time
 import random
 import torch
+from torch.optim.lr_scheduler import _LRScheduler
 from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
 
 from pytorch_lightning.core.lightning import LightningModule
@@ -298,8 +299,13 @@ def horovod_train(self, model):
         for param_group in optimizer.param_groups:
             param_group['lr'] *= hvd.size()
 
+    # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
+    for scheduler in self.lr_schedulers:
+        scheduler = scheduler['scheduler']
+        if isinstance(scheduler, _LRScheduler):
+            scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
+
     if self.use_amp:
-        # An example
         model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
         self.optimizers = optimizers
         self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

From 42cc438d41f0d4cb94d2e0cce82b14514ec9600f Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Thu, 16 Jul 2020 06:07:23 -0700
Subject: [PATCH 2/5] Added unit test

---
 tests/models/test_horovod.py | 40 +++++++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 77259df597ebc..02a6c8a244f19 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -5,6 +5,9 @@
 import subprocess
 import sys
 
+from unittest.mock import patch
+
+import numpy as np
 import pytest
 import torch
 
@@ -113,7 +116,6 @@ def test_horovod_multi_gpu(tmpdir):
 @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_horovod_transfer_batch_to_gpu(tmpdir):
-
     class TestTrainingStepModel(EvalModelTemplate):
         def training_step(self, batch, *args, **kwargs):
             x, y = batch
@@ -175,3 +177,39 @@ def get_optimizer_params(optimizer):
     assert get_model_params(model.generator) != get_model_params(model.discriminator)
     assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
     assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
+def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    num_workers = 8
+    init_lr = hparams.get('learning_rate') * num_workers
+
+    with patch('pytorch_lightning.trainer.distrib_parts.hvd.size') as mock_hvd_size:
+        mock_hvd_size.return_value = 8
+
+        # fit model
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=1,
+            limit_val_batches=0.5,
+            limit_train_batches=0.2,
+            distributed_backend='horovod'
+        )
+        results = trainer.fit(model)
+        assert results == 1
+
+    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
+    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
+
+    print(adjusted_lr1)
+    print(adjusted_lr2)
+
+    # Called once after end of epoch
+    assert pytest.approx(init_lr * 0.1) == adjusted_lr1
+
+    # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times
+    assert pytest.approx(init_lr * 0.1) == adjusted_lr2

From 83ba5bcce692d18bb80c5f0dd4b8b000ed2ea53e Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Thu, 16 Jul 2020 15:03:15 -0700
Subject: [PATCH 3/5] Removed debug statements

---
 tests/models/test_horovod.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 02a6c8a244f19..0d1f827187417 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -205,9 +205,6 @@ def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
     adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
     adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
 
-    print(adjusted_lr1)
-    print(adjusted_lr2)
-
     # Called once after end of epoch
     assert pytest.approx(init_lr * 0.1) == adjusted_lr1
 

From 13b10a3d6a633c8e5239e0f057aaab5a38a6793c Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Thu, 16 Jul 2020 15:09:40 -0700
Subject: [PATCH 4/5] Updated changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf200ea15f007..9537486b0c90b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed Horovod backend to scale LR schedulers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))
+
 
 ## [0.8.5] - 2020-07-09
 

From 560a4356c041b12fbc4ac2f23f0e9cf432f232c1 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 23 Jul 2020 00:41:20 +0200
Subject: [PATCH 5/5] Apply suggestions from code review

---
 tests/models/test_horovod.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 0d1f827187417..989ed56475dc6 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -205,8 +205,8 @@ def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
     adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
     adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
 
-    # Called once after end of epoch
+    # Called once after end of epoch with gamma=0.1
    assert pytest.approx(init_lr * 0.1) == adjusted_lr1
 
-    # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times
+    # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
     assert pytest.approx(init_lr * 0.1) == adjusted_lr2
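
Reviewer note, not part of the patch series: the reason scaling param_group['lr'] alone is not enough is that _LRScheduler subclasses keep their own copy of the initial LRs in base_lrs, and several of them (e.g. LambdaLR) compute every subsequent LR directly from that copy, so a scheduler constructed before the Horovod scaling would silently revert the LR on its first step. Below is a minimal standalone sketch of that interaction; LambdaLR, the 0.1 base LR, and the world size of 8 are illustrative assumptions (the PR's test exercises Lightning's EvalModelTemplate schedulers instead, and uses a mocked hvd.size()):

import torch
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler

# Toy stand-ins for what horovod_train receives; lr=0.1 and world_size=8
# are assumptions mirroring the mocked hvd.size() in the unit test.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.1 ** epoch)
world_size = 8  # stands in for hvd.size()

# What horovod_train already did before this patch: scale the live LR.
for param_group in optimizer.param_groups:
    param_group['lr'] *= world_size

# The fix: LambdaLR computes each LR as base_lr * lr_lambda(epoch), so
# without this rescale the first scheduler.step() would set the LR from
# the unscaled base (0.1 * 0.1 = 0.01) instead of the scaled one (0.08).
if isinstance(scheduler, _LRScheduler):
    scheduler.base_lrs = [lr * world_size for lr in scheduler.base_lrs]

optimizer.step()
scheduler.step()
assert abs(optimizer.param_groups[0]['lr'] - (0.1 * world_size) * 0.1) < 1e-9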