Skip to content

Commit 0ad3e8b

Browse files
changed to absolute imports and added docs (#881)
1 parent f44dfb3 commit 0ad3e8b

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

pl_examples/full_examples/semantic_segmentation/models/unet/__init__.py

-4
This file was deleted.

pl_examples/full_examples/semantic_segmentation/models/unet/model.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,33 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
from parts import DoubleConv, Down, Up
5+
from models.unet.parts import DoubleConv, Down, Up
66

77

88
class UNet(nn.Module):
99
'''
1010
Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation
1111
Link - https://arxiv.org/abs/1505.04597
12+
13+
Parameters:
14+
num_classes (int) - Number of output classes required (default 19 for KITTI dataset)
15+
bilinear (bool) - Whether to use bilinear interpolation or transposed
16+
convolutions for upsampling.
1217
'''
1318
def __init__(self, num_classes=19, bilinear=False):
1419
super().__init__()
15-
self.bilinear = bilinear
16-
self.num_classes = num_classes
1720
self.layer1 = DoubleConv(3, 64)
1821
self.layer2 = Down(64, 128)
1922
self.layer3 = Down(128, 256)
2023
self.layer4 = Down(256, 512)
2124
self.layer5 = Down(512, 1024)
2225

23-
self.layer6 = Up(1024, 512, bilinear=self.bilinear)
24-
self.layer7 = Up(512, 256, bilinear=self.bilinear)
25-
self.layer8 = Up(256, 128, bilinear=self.bilinear)
26-
self.layer9 = Up(128, 64, bilinear=self.bilinear)
26+
self.layer6 = Up(1024, 512, bilinear=bilinear)
27+
self.layer7 = Up(512, 256, bilinear=bilinear)
28+
self.layer8 = Up(256, 128, bilinear=bilinear)
29+
self.layer9 = Up(128, 64, bilinear=bilinear)
2730

28-
self.layer10 = nn.Conv2d(64, self.num_classes, kernel_size=1)
31+
self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1)
2932

3033
def forward(self, x):
3134
x1 = self.layer1(x)

0 commit comments

Comments
 (0)