Skip to content

Commit cf2d32d

Browse files
authored
fix bugs in semantic segmentation example (#1824)
* Update unet.py * Update semantic_segmentation.py
1 parent 1265b2f commit cf2d32d

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

pl_examples/domain_templates/semantic_segmentation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def validation_step(self, batch, batch_idx):
165165
return {'val_loss': loss_val}
166166

167167
def validation_epoch_end(self, outputs):
168-
loss_val = sum(output['val_loss'] for output in outputs) / len(outputs)
168+
loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
169169
log_dict = {'val_loss': loss_val}
170170
return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}
171171

pl_examples/models/unet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
feats *= 2
3434

3535
for _ in range(num_layers - 1):
36-
layers.append(Up(feats, feats // 2), bilinear)
36+
layers.append(Up(feats, feats // 2, bilinear))
3737
feats //= 2
3838

3939
layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

0 commit comments

Comments
 (0)