Horovod: adjust base LR used by schedulers to scale with the number of workers #2626

Merged: 6 commits, Jul 23, 2020

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed Horovod backend to scale LR schedulers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))

- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

## [0.8.5] - 2020-07-09
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/distrib_parts.py
@@ -10,6 +10,7 @@
import time
import random
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence

from pytorch_lightning.core.lightning import LightningModule
@@ -298,8 +299,13 @@ def horovod_train(self, model):
for param_group in optimizer.param_groups:
param_group['lr'] *= hvd.size()

# Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
for scheduler in self.lr_schedulers:
scheduler = scheduler['scheduler']
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

if self.use_amp:
# An example
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
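For context, here is a minimal standalone sketch (not part of this PR) of why scaling only the optimizer's `param_group['lr']` is not enough: schedulers such as `LambdaLR` recompute the learning rate from their stored `base_lrs` on every `step()`, so the worker scaling would otherwise be undone by the first scheduler step. The factor `hvd_size = 4` is an assumed stand-in for `hvd.size()`.

```python
# Sketch only: illustrates why scheduler.base_lrs must be scaled along with the
# optimizer's learning rate. `hvd_size` is an assumed stand-in for hvd.size().
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

hvd_size = 4

# Scale the optimizer LR, as Lightning already did before this PR
for param_group in optimizer.param_groups:
    param_group['lr'] *= hvd_size

optimizer.step()  # no-op here; avoids the "step() before optimizer.step()" warning
scheduler.step()
# LambdaLR recomputed the LR from the unscaled base_lrs: 0.1 * 0.95 = 0.095,
# so the hvd_size factor has been lost.
print(optimizer.param_groups[0]['lr'])

# With the fix from this PR, base_lrs track the scaled LR as well:
scheduler.base_lrs = [lr * hvd_size for lr in scheduler.base_lrs]
scheduler.step()
# Now lr = 0.4 * 0.95 ** 2 ≈ 0.361, i.e. the schedule applies on top of the scaled LR.
print(optimizer.param_groups[0]['lr'])
```
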
37 changes: 36 additions & 1 deletion tests/models/test_horovod.py
@@ -5,6 +5,9 @@
import subprocess
import sys

from unittest.mock import patch

import numpy as np
import pytest
import torch

@@ -113,7 +116,6 @@ def test_horovod_multi_gpu(tmpdir):
@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_horovod_transfer_batch_to_gpu(tmpdir):

class TestTrainingStepModel(EvalModelTemplate):
def training_step(self, batch, *args, **kwargs):
x, y = batch
@@ -175,3 +177,36 @@ def get_optimizer_params(optimizer):
assert get_model_params(model.generator) != get_model_params(model.discriminator)
assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0])
assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.configure_optimizers = model.configure_optimizers__multiple_schedulers

num_workers = 8
init_lr = hparams.get('learning_rate') * num_workers

with patch('pytorch_lightning.trainer.distrib_parts.hvd.size') as mock_hvd_size:
mock_hvd_size.return_value = 8

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0.5,
limit_train_batches=0.2,
distributed_backend='horovod'
)
results = trainer.fit(model)
assert results == 1

adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]

# Called once after the end of the epoch with gamma=0.1
assert pytest.approx(init_lr * 0.1) == adjusted_lr1

# Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1
assert pytest.approx(init_lr * 0.1) == adjusted_lr2
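
For readers using Horovod outside of Lightning, a hedged sketch of the same pattern follows (assuming `horovod.torch` is installed and initialized); it simply mirrors the `horovod_train` change above rather than adding anything new.

```python
# Sketch: the same base-LR adjustment in a plain Horovod + PyTorch script.
# Assumes horovod.torch is installed; with a single process, hvd.size() is 1
# and the scaling is a no-op.
import torch
import horovod.torch as hvd
from torch.optim.lr_scheduler import StepLR, _LRScheduler

hvd.init()

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

# Horovod: scale the learning rate by the number of workers
for param_group in optimizer.param_groups:
    param_group['lr'] *= hvd.size()

# Keep the scheduler's reference LRs consistent with the scaled optimizer
if isinstance(scheduler, _LRScheduler):
    scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
```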