horovod mode increase lr #2574

Closed
ruotianluo opened this issue Jul 10, 2020 · 13 comments
Labels
bug Something isn't working discussion In a discussion stage help wanted Open to be worked on

Comments

@ruotianluo
Contributor

ruotianluo commented Jul 10, 2020

Not really a 🐛 Bug

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/distrib_parts.py#L299

Under horovod mode, the learning rate will automatically be scaled (multiplied) by hvd.size().

This behavior is different from ddp, so it may confuse the users.
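For context, the behavior at the linked line is roughly the following (a paraphrase for illustration only, not the exact Lightning code; scale_lr_by_world_size is a made-up name):

import horovod.torch as hvd

# Sketch of the linked behavior: once Horovod is set up, every parameter
# group's learning rate is multiplied by the total number of workers.
def scale_lr_by_world_size(optimizers):
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= hvd.size()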

@ruotianluo ruotianluo added bug Something isn't working help wanted Open to be worked on labels Jul 10, 2020
@Borda
Member

Borda commented Jul 10, 2020

mind check @tgaddair @SkafteNicki ^^

@SkafteNicki
Member

I was not aware of this. It seems to be good practice that can really help boost training speed (see this paper https://arxiv.org/abs/1706.02677 and this issue horovod/horovod#384). So maybe we should implement something similar for the ddp backend.

@tgaddair
Contributor

Yes, in Horovod (and I believe DDP), increasing the number of workers is analogous to increasing the total batch size during training. As such, scaling the learning rate proportionately is considered a best practice. It's good to handle this internally in the distributed_backend, because some backends behave differently. For example, in Horovod, enabling the Adasum optimizer only requires scaling by the number of GPUs per host.
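To illustrate that point, the backend-dependent scale factor could look roughly like this (a hypothetical helper, not existing Lightning API; hvd.size() and hvd.local_size() are the actual Horovod calls for total workers and workers per host):

import horovod.torch as hvd

# Hypothetical helper: the appropriate scale factor depends on the backend
# configuration.
def lr_scale_factor(use_adasum: bool) -> int:
    if use_adasum:
        # With Adasum, scale only by the number of GPUs per host.
        return hvd.local_size()
    # Otherwise, scale linearly by the total number of workers.
    return hvd.size()

# scaled_lr = base_lr * lr_scale_factor(use_adasum=False)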

@ruotianluo
Contributor Author

I agree it is good practice. However, if it's not the only way, I don't think it should be the default, especially without notifying the users.

It could instead be an argument of the trainer.

@tgaddair
Contributor

I think making it configurable is reasonable. However, I do think it should be enabled by default. Part of the goal of the Trainer abstraction is to make distributed training accessible to people who are not familiar with distributed training concepts / best practices.

For most users, unless they are using custom learning rate schedules or unusual optimizers, they will want to scale the learning rate. At the same time, most users would not know to do this themselves, so I fear without enabling it by default, they would not do so, and their models would converge worse as a result.

@ruotianluo
Contributor Author

ruotianluo commented Jul 11, 2020

BTW, there may be a problem when the lr_scheduler is LambdaLR. It seems that LambdaLR collects the lrs from the optimizer at construction time and saves them as base_lrs. The lambda function is then applied to those base_lrs, so even if you change the lr later, the lr scheduler will ignore the change.

https://github.com/pytorch/pytorch/blob/879cf0b15a54c7848ae710e3d0ec62c4a9d7d3dd/torch/optim/lr_scheduler.py#L43

LambdaLR is, I believe, a commonly used scheduler. For now I think it's safer to delete that line, and then think about the best way to implement this.

Of course, correct me if I am wrong.
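A minimal standalone sketch of the interaction described above (made-up parameter and numbers, only to show that LambdaLR snapshots base_lrs at construction and discards later changes to the optimizer's lr):

import torch
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# LambdaLR records the current lrs as base_lrs when it is constructed.
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)

# Scaling the lr afterwards (as the Horovod backend does) ...
optimizer.param_groups[0]['lr'] *= 4

# ... is overwritten on the next step, because the scheduler recomputes
# lr = base_lr * lambda(epoch) from the stored base_lrs.
scheduler.step()
print(optimizer.param_groups[0]['lr'])  # back to 0.1, the scaling is lost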

@tgaddair
Contributor

Hey @ruotianluo, that's a good point regarding interaction with LambdaLR and other LR schedulers. Can you take a look at #2626 and see if it addresses your concern?

@ruotianluo
Contributor Author

@tgaddair I still want to argue against scaling the learning rate by default. From a quick search, it doesn't seem to me that in NLP people do the same learning rate scaling. BERT uses batch size 256 and learning rate 1e-4; RoBERTa uses batch size 8k and a max learning rate of 4e-4/6e-4 (depending on the model size). I think it may be related to the optimizer (in NLP it's usually Adam). I don't know if this fact can convince you.

@tgaddair
Contributor

Hey @ruotianluo, even when training BERT with Horovod, it's common practice to scale the learning rate. See:

Fundamentally, when you add more workers, you are increasing the batch size. That holds true whether it is a vision task, NLP, or other scenarios. So you need to account for that somehow (most commonly through LR scaling, though I imagine other means are possible as well).
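As a made-up numeric example of that reasoning (all values hypothetical):

# Per-worker batch size stays fixed, so adding workers grows the effective batch.
per_worker_batch_size = 32
num_workers = 8
base_lr = 0.1

effective_batch_size = per_worker_batch_size * num_workers  # 256
linearly_scaled_lr = base_lr * num_workers  # 0.8 under the linear scaling rule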

I do agree we should make this configurable, though. I'm interested in putting together a separate PR for this, but it should include changes to DDP as well.

@ruotianluo
Contributor Author

https://arxiv.org/pdf/1904.00962.pdf. This paper uses square-root scaling for BERT (and also for ImageNet classification). And ALBERT (from Google) uses this approach (https://arxiv.org/pdf/1909.11942.pdf).

The first link you provided doesn't have any results. For the second, I didn't see any quantitative results on how that affects training either (and it is not merged yet).

Do other frameworks do learning rate scaling by default too (Keras, fastai)? If it's common across other libraries, I think it's fine too.

@Borda Borda added the discussion In a discussion stage label Jul 23, 2020
@tgaddair
Contributor

tgaddair commented Jul 23, 2020

Hey @ruotianluo, in my experience, frameworks that expose distributed training to users as an API (like tf.distributed) will mention in their docs that it's good practice to scale the LR, but will leave it to the user to do so (this is what we do with Horovod as well).

However, frameworks that attempt to completely abstract away distributed training (like PyTorch Lightning is seeking to do) should provide a good reasonable default.

I agree with you that in practice, it may be that linearly scaling the learning rate does not provide the best model performance, in which case researchers will often hand-tune the combination of learning rate and total batch size (i.e., number of workers) to obtain the best performance. To support that, there is definitely a need to make learning rate adjustment configurable.

At the same time, whatever solution we come up with needs to be backend-agnostic. One of the selling points of PL is the ability to swap out different distributed backends. If we couple the LR scaling to the backend (e.g., require the user to put lr * hvd.size() in the LightningModule), we lose a lot of the benefit.

With that in mind, here's what I'm currently thinking could be a good solution:

  1. Provide a good reasonable default for users who are not experts in distributed training (linear learning rate scaling) for Horovod and DDP.
  2. Provide an optional method in the LightningModule that allows the user to adjust the learning rate as a function of the number of workers, independent of the specific backend being used, which will override the default in (1):
from math import sqrt

from pytorch_lightning import LightningModule


class MyModule(LightningModule):

    ...

    def adjust_learning_rate(self, base_lr, world_size):
        # e.g. override the default linear scaling with square-root scaling
        return base_lr * sqrt(world_size)

@ruotianluo @williamFalcon @Borda @SkafteNicki what do you think?

@Borda
Member

Borda commented Jul 23, 2020

cc: @PyTorchLightning/core-contributors

@Borda
Member

Borda commented Aug 4, 2020

shall be resolved in #2626

@Borda Borda closed this as completed Aug 4, 2020