From 49403a0db5f1af18d289717b4d057a0f9904af38 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:04:58 -0400
Subject: [PATCH 01/10] removes need to unsqueeze from dp

---
 pytorch_lightning/overrides/data_parallel.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 6b922f65a526d..6566e2c39f9e0 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -163,6 +163,9 @@ def _worker(i, module, input, kwargs, device=None):
             else:
                 output = module.validation_step(*input, **kwargs)
+
+            if module.use_dp or module.use_ddp2:
+                output['loss'] = output['loss'].unsqueeze()
 
         # ---------------
         with lock:

From 0fb19b5e86f92fd296c08e37c1efbf092a4999af Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:05:51 -0400
Subject: [PATCH 02/10] removes need to unsqueeze from dp

---
 pytorch_lightning/overrides/data_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 6566e2c39f9e0..7733d5a515c75 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -165,7 +165,7 @@ def _worker(i, module, input, kwargs, device=None):
                 output = module.validation_step(*input, **kwargs)
 
             if module.use_dp or module.use_ddp2:
-                output['loss'] = output['loss'].unsqueeze()
+                output['loss'] = output['loss'].unsqueeze(0)
 
         # ---------------
         with lock:
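
A minimal sketch (not part of the patch series) of why the added `unsqueeze` matters: `nn.DataParallel` gathers per-GPU step outputs by concatenating them along dim 0, and a 0-dim scalar loss has no dimension to concatenate along. `unsqueeze(0)` gives each replica's loss shape `[1]`, so the gathered result holds one entry per GPU. The two worker losses below are made-up values:

```python
import torch

# stand-ins for the 0-dim scalar 'loss' returned by two DP workers
scalar_losses = [torch.tensor(0.25), torch.tensor(0.30)]

# torch.cat(scalar_losses) would fail:
# RuntimeError: zero-dimensional tensor ... cannot be concatenated

# after the per-worker unsqueeze, gathering along dim 0 works
unsqueezed = [loss.unsqueeze(0) for loss in scalar_losses]  # each has shape [1]
gathered = torch.cat(unsqueezed, dim=0)                     # shape [2], one entry per GPU
print(gathered.mean())  # tensor(0.2750), the reduced loss
```
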
From bab36f941c78263db8717474c7fb069bb9d7094e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:32:07 -0400
Subject: [PATCH 03/10] fixed examples

---
 .../lightning_module_template.py               |  9 ---------
 pl_examples/domain_templates/gan.py            | 18 +++++-------------
 .../domain_templates/reinforse_learn_Qnet.py   |  3 ---
 .../full_examples/imagenet/imagenet_example.py | 12 ------------
 4 files changed, 5 insertions(+), 37 deletions(-)

diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py
index f436db2872395..9c65bac9ac7e0 100644
--- a/pl_examples/basic_examples/lightning_module_template.py
+++ b/pl_examples/basic_examples/lightning_module_template.py
@@ -111,10 +111,6 @@ def training_step(self, batch, batch_idx):
         # calculate loss
         loss_val = self.loss(y, y_hat)
 
-        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
-        if self.trainer.use_dp or self.trainer.use_ddp2:
-            loss_val = loss_val.unsqueeze(0)
-
         tqdm_dict = {'train_loss': loss_val}
         output = OrderedDict({
             'loss': loss_val,
@@ -145,11 +141,6 @@ def validation_step(self, batch, batch_idx):
         if self.on_gpu:
             val_acc = val_acc.cuda(loss_val.device.index)
 
-        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
-        if self.trainer.use_dp or self.trainer.use_ddp2:
-            loss_val = loss_val.unsqueeze(0)
-            val_acc = val_acc.unsqueeze(0)
-
         output = OrderedDict({
             'val_loss': loss_val,
             'val_acc': val_acc,
diff --git a/pl_examples/domain_templates/gan.py b/pl_examples/domain_templates/gan.py
index 68e6053e7e822..3661b3637ecb4 100644
--- a/pl_examples/domain_templates/gan.py
+++ b/pl_examples/domain_templates/gan.py
@@ -99,10 +99,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         if optimizer_idx == 0:
             # sample noise
             z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
-
-            # match gpu device (or keep as cpu)
-            if self.on_gpu:
-                z = z.cuda(imgs.device.index)
+            z = z.type_as(imgs)
 
             # generate images
             self.generated_imgs = self(z)
@@ -115,8 +112,7 @@
             # ground truth result (ie: all fake)
             # put on GPU because we created this tensor inside training_loop
             valid = torch.ones(imgs.size(0), 1)
-            if self.on_gpu:
-                valid = valid.cuda(imgs.device.index)
+            valid = valid.type_as(imgs)
 
             # adversarial loss is binary cross-entropy
             g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
@@ -134,15 +130,13 @@
             # how well can it label as real?
             valid = torch.ones(imgs.size(0), 1)
-            if self.on_gpu:
-                valid = valid.cuda(imgs.device.index)
+            valid = valid.type_as(imgs)
 
             real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
 
             # how well can it label as fake?
             fake = torch.zeros(imgs.size(0), 1)
-            if self.on_gpu:
-                fake = fake.cuda(imgs.device.index)
+            fake = fake.type_as(imgs)
 
             fake_loss = self.adversarial_loss(
                 self.discriminator(self.generated_imgs.detach()), fake)
@@ -174,9 +168,7 @@ def train_dataloader(self):
 
     def on_epoch_end(self):
         z = torch.randn(8, self.hparams.latent_dim)
-        # match gpu device (or keep as cpu)
-        if self.on_gpu:
-            z = z.cuda(self.last_imgs.device.index)
+        z = z.type_as(self.last_imgs)
 
         # log sampled images
         sample_imgs = self(z)
diff --git a/pl_examples/domain_templates/reinforse_learn_Qnet.py b/pl_examples/domain_templates/reinforse_learn_Qnet.py
index 4585c108d5cfb..5a797d7d89a9d 100644
--- a/pl_examples/domain_templates/reinforse_learn_Qnet.py
+++ b/pl_examples/domain_templates/reinforse_learn_Qnet.py
@@ -277,9 +277,6 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O
         # calculates training loss
         loss = self.dqn_mse_loss(batch)
 
-        if self.trainer.use_dp or self.trainer.use_ddp2:
-            loss = loss.unsqueeze(0)
-
         if done:
             self.total_reward = self.episode_reward
             self.episode_reward = 0
diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py
index ad8f90f5a10b6..52c5cf0642f0b 100644
--- a/pl_examples/full_examples/imagenet/imagenet_example.py
+++ b/pl_examples/full_examples/imagenet/imagenet_example.py
@@ -46,12 +46,6 @@ def training_step(self, batch, batch_idx):
         loss_val = F.cross_entropy(output, target)
         acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
 
-        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
-        if self.trainer.use_dp or self.trainer.use_ddp2:
-            loss_val = loss_val.unsqueeze(0)
-            acc1 = acc1.unsqueeze(0)
-            acc5 = acc5.unsqueeze(0)
-
         tqdm_dict = {'train_loss': loss_val}
         output = OrderedDict({
             'loss': loss_val,
@@ -69,12 +63,6 @@ def validation_step(self, batch, batch_idx):
         loss_val = F.cross_entropy(output, target)
         acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
 
-        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
-        if self.trainer.use_dp or self.trainer.use_ddp2:
-            loss_val = loss_val.unsqueeze(0)
-            acc1 = acc1.unsqueeze(0)
-            acc5 = acc5.unsqueeze(0)
-
         output = OrderedDict({
             'val_loss': loss_val,
             'val_acc1': acc1,
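
The examples above also replace `if self.on_gpu: x = x.cuda(...)` branches with `type_as`, which casts a freshly created tensor to the dtype of a tensor that already lives in the batch (and, for a CUDA batch, moves it onto a GPU as well). A small stand-alone illustration (the shapes are arbitrary, not taken from the examples):

```python
import torch

imgs = torch.rand(4, 3, 28, 28)      # stand-in for a batch tensor; may be on CPU or GPU
z = torch.randn(imgs.shape[0], 100)  # new tensor, created on CPU by default

# cast z to imgs' dtype; if imgs is a CUDA tensor this also moves z onto
# the GPU, and on CPU it is a no-op, so no device branching is needed
z = z.type_as(imgs)

assert z.dtype == imgs.dtype
```
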
From 3fb70e5fa4715da9a74596cb8288f872805e5649 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:39:48 -0400
Subject: [PATCH 04/10] added auto unsqueeze

---
 pytorch_lightning/overrides/data_parallel.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 7733d5a515c75..16d30686c866b 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -165,7 +165,7 @@ def _worker(i, module, input, kwargs, device=None):
                 output = module.validation_step(*input, **kwargs)
 
             if module.use_dp or module.use_ddp2:
-                output['loss'] = output['loss'].unsqueeze(0)
+                auto_squeeze_dim_zeros(output)
 
         # ---------------
         with lock:
@@ -202,3 +202,15 @@ def _worker(i, module, input, kwargs, device=None):
             raise output
         outputs.append(output)
     return outputs
+
+
+def auto_squeeze_dim_zeros(output):
+    """
+    In DP or DDP2 we need to unsqueeze dim 0
+    :param output:
+    :return:
+    """
+    for k, v in output:
+        is_scalar = len(v.size())
+        if is_scalar:
+            output[k] = output[k].unsqueeze(0)

From f2bea75922afbfa4f09beab58d842fb61bbebf4c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:45:11 -0400
Subject: [PATCH 05/10] added auto unsqueeze

---
 pytorch_lightning/overrides/data_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 16d30686c866b..566f748b258b9 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -210,7 +210,7 @@ def auto_squeeze_dim_zeros(output):
     :param output:
     :return:
     """
-    for k, v in output:
+    for k, v in output.items():
         is_scalar = len(v.size())
         if is_scalar:
             output[k] = output[k].unsqueeze(0)

From 9b5b6b8abee46794c45bb12a9dc1d9681c6d1646 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:49:23 -0400
Subject: [PATCH 06/10] added auto unsqueeze

---
 pytorch_lightning/overrides/data_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 566f748b258b9..e08deaa61b2cf 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -211,6 +211,6 @@ def auto_squeeze_dim_zeros(output):
     :return:
     """
     for k, v in output.items():
-        is_scalar = len(v.size())
+        is_scalar = len(v.size()) == 0
         if is_scalar:
             output[k] = output[k].unsqueeze(0)

From fe4468546bd50823b7cbe91ad5cafaf93db6ac91 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 31 Mar 2020 10:51:54 -0400
Subject: [PATCH 07/10] added auto unsqueeze

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 02110f6e68ecc..e5e19e4a20b10 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
 - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))
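
A quick sketch of what `auto_squeeze_dim_zeros` does at this point in the series (the helper body is copied from patch 06; the sample dict is made up): every 0-dim tensor in the step output gains a leading dimension, while already-batched tensors pass through untouched.

```python
import torch

def auto_squeeze_dim_zeros(output):
    # helper as of patch 06: unsqueeze dim 0 on scalar tensors only
    for k, v in output.items():
        is_scalar = len(v.size()) == 0
        if is_scalar:
            output[k] = output[k].unsqueeze(0)

output = {'loss': torch.tensor(0.5), 'batch_preds': torch.rand(3)}
auto_squeeze_dim_zeros(output)
print(output['loss'].shape)         # torch.Size([1]) -- unsqueezed
print(output['batch_preds'].shape)  # torch.Size([3]) -- unchanged
```

Note that a non-tensor value, such as the nested `progress_bar` or `log` dicts a step can return, has no `.size()` and would raise here; that is the case patches 09 and 10 below handle.
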
From 337d67d683396318cc07a93395ef0f5f15df5b13 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 1 Apr 2020 12:14:13 -0400
Subject: [PATCH 08/10] Update pytorch_lightning/overrides/data_parallel.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-Authored-By: Adrian Wälchli
---
 pytorch_lightning/overrides/data_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index e08deaa61b2cf..b900492982d73 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -211,6 +211,6 @@ def auto_squeeze_dim_zeros(output):
     :return:
     """
     for k, v in output.items():
-        is_scalar = len(v.size()) == 0
+        is_scalar = v.dim() == 0
         if is_scalar:
             output[k] = output[k].unsqueeze(0)

From 9d3343af515cc284419dfc0a5f4bd13a5304e4fa Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 1 Apr 2020 23:50:42 -0400
Subject: [PATCH 09/10] fixed dp parse

---
 pytorch_lightning/overrides/data_parallel.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index b900492982d73..45ea81f1830cc 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -211,6 +211,9 @@ def auto_squeeze_dim_zeros(output):
     :return:
     """
     for k, v in output.items():
+        if not isinstance(v, torch.Tensor):
+            pass
+
         is_scalar = v.dim() == 0
         if is_scalar:
             output[k] = output[k].unsqueeze(0)

From 6cccacf2af77bce421ef638bf2e700faf9ec966a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 1 Apr 2020 23:54:26 -0400
Subject: [PATCH 10/10] fixed dp parse

---
 pytorch_lightning/overrides/data_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 45ea81f1830cc..168cdf7e1798b 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -212,7 +212,7 @@ def auto_squeeze_dim_zeros(output):
     """
     for k, v in output.items():
         if not isinstance(v, torch.Tensor):
-            pass
+            continue
 
         is_scalar = v.dim() == 0
         if is_scalar:
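
Taken together, the series means a `LightningModule` no longer needs DP/DDP2-specific shape handling in its step methods. A hedged sketch of the end state (the `LitClassifier` module and its body are illustrative, not code from this PR):

```python
import pytorch_lightning as pl
import torch.nn.functional as F

class LitClassifier(pl.LightningModule):  # hypothetical example module
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)  # 0-dim scalar loss
        # no manual `if self.trainer.use_dp or self.trainer.use_ddp2:
        #     loss = loss.unsqueeze(0)`
        # the DP/DDP2 worker now unsqueezes scalar tensor outputs itself
        return {'loss': loss}
```
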