Skip to content

Commit 43ac63f

Browse files
Fix segmentation example (#876)
* removed torchvision model and added custom model * minor fix * Fixed relative imports issue
1 parent 6029fad commit 43ac63f

File tree

4 files changed

+117
-5
lines changed

4 files changed

+117
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
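# Put this file's directory on sys.path so that modules located next to it
# can be imported by their bare name (e.g. `from parts import ...`);
# added to make these imports work on Python 3.6.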
2+
import os
3+
import sys
4+
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
from parts import DoubleConv, Down, Up
6+
7+
8+
class UNet(nn.Module):
    '''
    U-Net style encoder/decoder network for per-pixel classification.

    Architecture based on "U-Net: Convolutional Networks for Biomedical
    Image Segmentation" (https://arxiv.org/abs/1505.04597).

    The contracting path is one DoubleConv followed by four Down blocks
    (3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels); the expanding path is
    four Up blocks (1024 -> 512 -> 256 -> 128 -> 64 channels), each of which
    also receives the matching feature map from the contracting path.
    A final 1x1 convolution maps the 64 channels to `num_classes`.

    Args:
        num_classes: number of output channels produced by the last layer.
        bilinear: forwarded unchanged to every Up block to select its
            upsampling mode.
    '''
    def __init__(self, num_classes=19, bilinear=False):
        super().__init__()
        self.bilinear = bilinear
        self.num_classes = num_classes

        # Contracting path (the first block expects a 3-channel input).
        self.layer1 = DoubleConv(3, 64)
        self.layer2 = Down(64, 128)
        self.layer3 = Down(128, 256)
        self.layer4 = Down(256, 512)
        self.layer5 = Down(512, 1024)

        # Expanding path.
        self.layer6 = Up(1024, 512, bilinear=self.bilinear)
        self.layer7 = Up(512, 256, bilinear=self.bilinear)
        self.layer8 = Up(256, 128, bilinear=self.bilinear)
        self.layer9 = Up(128, 64, bilinear=self.bilinear)

        # Per-pixel classifier.
        self.layer10 = nn.Conv2d(64, self.num_classes, kernel_size=1)

    def forward(self, x):
        '''
        Run the network on `x` and return the output of the final 1x1 conv.
        '''
        # Walk down the contracting path, remembering the input of every
        # Down block so it can be handed to the matching Up block later.
        skips = []
        features = self.layer1(x)
        for down in (self.layer2, self.layer3, self.layer4, self.layer5):
            skips.append(features)
            features = down(features)

        # Walk back up; the most recently stored map is consumed first.
        for up in (self.layer6, self.layer7, self.layer8, self.layer9):
            features = up(features, skips.pop())

        return self.layer10(features)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class DoubleConv(nn.Module):
    '''
    Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    The first stage maps `in_ch` channels to `out_ch`; the second keeps
    `out_ch` channels. Both convolutions use padding=1.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by both convolutions.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Only the input width differs between the two stages.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        '''Apply both stages to `x` and return the result.'''
        return self.net(x)
24+
25+
26+
class Down(nn.Module):
    '''
    One step of the contracting path: a 2x2 max-pool with stride 2
    followed by a DoubleConv.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by the DoubleConv.
    '''
    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        convs = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        '''Pool `x`, then run the double convolution on the result.'''
        return self.net(x)
39+
40+
41+
class Up(nn.Module):
    '''
    One step of the expanding path.

    Upsamples the incoming feature map by a factor of 2 (by either bilinear
    interpolation or a transpose convolution) while halving its channel
    count, concatenates it with a feature map from the contracting path,
    and applies a double 3x3 convolution.

    Args:
        in_ch: number of channels of the tensor to upsample. The
            concatenated tensor fed to the DoubleConv must also have
            `in_ch` channels, i.e. the skip tensor is expected to carry
            `in_ch - in_ch // 2` channels.
        out_ch: number of channels produced by the DoubleConv.
        bilinear: if True, upsample with bilinear interpolation followed by
            a 1x1 convolution; otherwise use a 2x2 transpose convolution
            with stride 2.
    '''
    def __init__(self, in_ch, out_ch, bilinear=False):
        super().__init__()
        if bilinear:
            # nn.Upsample only changes the spatial size, so a 1x1 conv is
            # needed to halve the channels. Without it the concatenated
            # tensor would have in_ch + in_ch // 2 channels and would not
            # match the DoubleConv below.
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
            )
        else:
            self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        '''
        Upsample `x1`, align it with `x2`, concatenate and convolve.

        Args:
            x1: tensor coming from the previous (deeper) layer.
            x2: skip-connection tensor from the contracting path; indexed
                as a 4-D tensor with height at dim 2 and width at dim 3.

        Returns:
            Output of the DoubleConv applied to `cat([x2, x1], dim=1)`.
        '''
        x1 = self.upsample(x1)

        # Pad x1 to the size of x2. When a difference is odd, the extra
        # row/column goes to the bottom/right side.
        diff_h = x2.shape[2] - x1.shape[2]
        diff_w = x2.shape[3] - x1.shape[3]

        # F.pad takes the last dimension first: [left, right, top, bottom].
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

        # Concatenate along the channels axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

pl_examples/full_examples/semantic_segmentation/semseg.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import torchvision
1111
import torchvision.transforms as transforms
1212
from torch.utils.data import DataLoader, Dataset
13-
from torchvision.models.segmentation import fcn_resnet50
1413

1514
import pytorch_lightning as pl
15+
from models.unet.model import UNet
1616

1717

1818
class KITTI(Dataset):
@@ -128,9 +128,7 @@ def __init__(self, hparams):
128128
self.root_path = hparams.root
129129
self.batch_size = hparams.batch_size
130130
self.learning_rate = hparams.lr
131-
self.net = torchvision.models.segmentation.fcn_resnet50(pretrained=False,
132-
progress=True,
133-
num_classes=19)
131+
self.net = UNet(num_classes=19)
134132
self.transform = transforms.Compose([
135133
transforms.ToTensor(),
136134
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
@@ -147,7 +145,7 @@ def training_step(self, batch, batch_nb):
147145
img = img.float()
148146
mask = mask.long()
149147
out = self.forward(img)
150-
loss_val = F.cross_entropy(out['out'], mask, ignore_index=250)
148+
loss_val = F.cross_entropy(out, mask, ignore_index=250)
151149
return {'loss': loss_val}
152150

153151
def configure_optimizers(self):

0 commit comments

Comments
 (0)