Skip to content

Commit 3cb149f

Browse files
williamFalconAdrian Wälchli
and
Adrian Wälchli
authored
Removes need to unsqueeze from dp (#1319)
* removes need to unsqueeze from dp * removes need to unsqueeze from dp * fixed examples * added auto unsqueeze * added auto unsqueeze * added auto unsqueeze * added auto unsqueeze * Update pytorch_lightning/overrides/data_parallel.py Co-Authored-By: Adrian Wälchli <[email protected]> * fixed dp parse * fixed dp parse Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6b41b5c commit 3cb149f

File tree

6 files changed

+24
-37
lines changed

6 files changed

+24
-37
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2525

2626
### Changed
2727

28+
- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
2829
- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
2930
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
3031
- Made `evaluate` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))

pl_examples/basic_examples/lightning_module_template.py

-9
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,6 @@ def training_step(self, batch, batch_idx):
111111
# calculate loss
112112
loss_val = self.loss(y, y_hat)
113113

114-
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
115-
if self.trainer.use_dp or self.trainer.use_ddp2:
116-
loss_val = loss_val.unsqueeze(0)
117-
118114
tqdm_dict = {'train_loss': loss_val}
119115
output = OrderedDict({
120116
'loss': loss_val,
@@ -145,11 +141,6 @@ def validation_step(self, batch, batch_idx):
145141
if self.on_gpu:
146142
val_acc = val_acc.cuda(loss_val.device.index)
147143

148-
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
149-
if self.trainer.use_dp or self.trainer.use_ddp2:
150-
loss_val = loss_val.unsqueeze(0)
151-
val_acc = val_acc.unsqueeze(0)
152-
153144
output = OrderedDict({
154145
'val_loss': loss_val,
155146
'val_acc': val_acc,

pl_examples/domain_templates/gan.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
9999
if optimizer_idx == 0:
100100
# sample noise
101101
z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
102-
103-
# match gpu device (or keep as cpu)
104-
if self.on_gpu:
105-
z = z.cuda(imgs.device.index)
102+
z = z.type_as(imgs)
106103

107104
# generate images
108105
self.generated_imgs = self(z)
@@ -115,8 +112,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
115112
# ground truth result (ie: all fake)
116113
# put on GPU because we created this tensor inside training_loop
117114
valid = torch.ones(imgs.size(0), 1)
118-
if self.on_gpu:
119-
valid = valid.cuda(imgs.device.index)
115+
valid = valid.type_as(imgs)
120116

121117
# adversarial loss is binary cross-entropy
122118
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
@@ -134,15 +130,13 @@ def training_step(self, batch, batch_idx, optimizer_idx):
134130

135131
# how well can it label as real?
136132
valid = torch.ones(imgs.size(0), 1)
137-
if self.on_gpu:
138-
valid = valid.cuda(imgs.device.index)
133+
valid = valid.type_as(imgs)
139134

140135
real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
141136

142137
# how well can it label as fake?
143138
fake = torch.zeros(imgs.size(0), 1)
144-
if self.on_gpu:
145-
fake = fake.cuda(imgs.device.index)
139+
fake = fake.type_as(imgs)
146140

147141
fake_loss = self.adversarial_loss(
148142
self.discriminator(self.generated_imgs.detach()), fake)
@@ -174,9 +168,7 @@ def train_dataloader(self):
174168

175169
def on_epoch_end(self):
176170
z = torch.randn(8, self.hparams.latent_dim)
177-
# match gpu device (or keep as cpu)
178-
if self.on_gpu:
179-
z = z.cuda(self.last_imgs.device.index)
171+
z = z.type_as(self.last_imgs)
180172

181173
# log sampled images
182174
sample_imgs = self(z)

pl_examples/domain_templates/reinforse_learn_Qnet.py

-3
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,6 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O
277277
# calculates training loss
278278
loss = self.dqn_mse_loss(batch)
279279

280-
if self.trainer.use_dp or self.trainer.use_ddp2:
281-
loss = loss.unsqueeze(0)
282-
283280
if done:
284281
self.total_reward = self.episode_reward
285282
self.episode_reward = 0

pl_examples/full_examples/imagenet/imagenet_example.py

-12
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,6 @@ def training_step(self, batch, batch_idx):
4646
loss_val = F.cross_entropy(output, target)
4747
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
4848

49-
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
50-
if self.trainer.use_dp or self.trainer.use_ddp2:
51-
loss_val = loss_val.unsqueeze(0)
52-
acc1 = acc1.unsqueeze(0)
53-
acc5 = acc5.unsqueeze(0)
54-
5549
tqdm_dict = {'train_loss': loss_val}
5650
output = OrderedDict({
5751
'loss': loss_val,
@@ -69,12 +63,6 @@ def validation_step(self, batch, batch_idx):
6963
loss_val = F.cross_entropy(output, target)
7064
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
7165

72-
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
73-
if self.trainer.use_dp or self.trainer.use_ddp2:
74-
loss_val = loss_val.unsqueeze(0)
75-
acc1 = acc1.unsqueeze(0)
76-
acc5 = acc5.unsqueeze(0)
77-
7866
output = OrderedDict({
7967
'val_loss': loss_val,
8068
'val_acc1': acc1,

pytorch_lightning/overrides/data_parallel.py

+18
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def _worker(i, module, input, kwargs, device=None):
163163

164164
else:
165165
output = module.validation_step(*input, **kwargs)
166+
167+
if module.use_dp or module.use_ddp2:
168+
auto_squeeze_dim_zeros(output)
166169
# ---------------
167170

168171
with lock:
@@ -199,3 +202,18 @@ def _worker(i, module, input, kwargs, device=None):
199202
raise output
200203
outputs.append(output)
201204
return outputs
205+
206+
207+
def auto_squeeze_dim_zeros(output):
    """
    Give every scalar (0-dim) tensor in a step output a leading dim of size 1.

    NOTE: despite the name, this function *unsqueezes* dim 0; the name is kept
    so existing callers keep working.

    It is called for DP / DDP2 step outputs; presumably the per-device results
    are later gathered along dim 0, which scalars do not have (confirm against
    the gather code, which is not part of this function).

    :param output: the value returned by a step. A dict-like object is updated
        in place: each top-level value that is a 0-dim ``torch.Tensor`` is
        replaced by its ``unsqueeze(0)`` version. Nested containers are not
        visited. A bare 0-dim tensor cannot be changed in place, so the caller
        must use the return value. Anything else (e.g. ``None``) is left as is.
    :return: ``output`` itself for dict-like and unsupported inputs, or the
        (possibly unsqueezed) tensor when a bare tensor was passed.
    """
    if isinstance(output, torch.Tensor):
        # a bare tensor has no container to update, hand the new tensor back
        return output.unsqueeze(0) if output.dim() == 0 else output

    if not hasattr(output, 'items'):
        # e.g. a step that returned None: nothing to fix, and nothing to crash on
        return output

    for k, v in output.items():
        # only scalar tensors need the extra dim; everything else is untouched
        if isinstance(v, torch.Tensor) and v.dim() == 0:
            # re-assigning an existing key does not resize the dict, so this is
            # safe while iterating over it
            output[k] = v.unsqueeze(0)

    return output

0 commit comments

Comments
 (0)