Skip to content

Commit 11db9d2

Browse files
added initial semantic segmentation example (#751)
* added initial semantic segmentation example * removed unnecessary lines. * changed according to reviews * minor changes * Added some documentation for Dataset class * Fixed some long lines * added docstring for LightningModule
1 parent d3d7e7b commit 11db9d2

File tree

1 file changed

+193
-0
lines changed
  • pl_examples/full_examples/semantic_segmentation

1 file changed

+193
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import os
2+
from argparse import ArgumentParser
3+
from collections import OrderedDict
4+
from PIL import Image
5+
6+
import numpy as np
7+
import torch
8+
import torch.nn as nn
9+
import torch.nn.functional as F
10+
import torchvision
11+
import torchvision.transforms as transforms
12+
from torch.utils.data import DataLoader, Dataset
13+
from torchvision.models.segmentation import fcn_resnet50
14+
15+
import pytorch_lightning as pl
16+
17+
18+
class KITTI(Dataset):
    '''
    Dataset class for the KITTI Semantic Segmentation Benchmark dataset.
    Dataset link - http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

    There are 34 classes in the given labels. However, not all of them are useful for training
    (like railings on highways, road dividers, etc.).
    So, these useless classes (the pixel values of these classes) are stored in `void_labels`.
    The useful classes are stored in `valid_labels`.

    The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
    (250) and maps every valid label to its position in `valid_labels`, i.e. a value between
    0 and `len(valid_labels) - 1`, so it can be compared with the network output by the loss.

    The `get_filenames` function returns the sorted paths of all files inside a directory.
    Sorting keeps `img_list` and `mask_list` aligned when both directories use the same
    file names.

    In `__getitem__`, images and masks are resized to `img_size` (passed straight to PIL's
    `resize`), masks are encoded using `encode_segmap`, and the given `transform` (if any) is
    applied to the image only.

    Args:
        root_path: dataset root containing `training/` and `testing/` sub-directories.
        split: `'train'` yields `(img, mask)` pairs; any other value yields images only,
            read from `testing/image_2`.
        img_size: target size handed to `PIL.Image.Image.resize`.
        void_labels: raw label ids to ignore; `None` selects `DEFAULT_VOID_LABELS`.
        valid_labels: raw label ids to train on; `None` selects `DEFAULT_VALID_LABELS`.
        transform: optional callable applied to the image array.
    '''
    # Immutable defaults: kept as tuples so that instances never share a mutable list.
    DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
    DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
                            31, 32, 33)

    def __init__(
        self,
        root_path,
        split='test',
        img_size=(1242, 376),
        void_labels=None,
        valid_labels=None,
        transform=None
    ):
        self.img_size = img_size
        # Copy into fresh lists so every instance owns its label lists.
        self.void_labels = list(
            self.DEFAULT_VOID_LABELS if void_labels is None else void_labels
        )
        self.valid_labels = list(
            self.DEFAULT_VALID_LABELS if valid_labels is None else valid_labels
        )
        self.ignore_index = 250
        # Raw label id -> contiguous training id (its position in `valid_labels`).
        self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
        self.split = split
        self.root = root_path
        if self.split == 'train':
            self.img_path = os.path.join(self.root, 'training/image_2')
            self.mask_path = os.path.join(self.root, 'training/semantic')
        else:
            self.img_path = os.path.join(self.root, 'testing/image_2')
            self.mask_path = None

        self.transform = transform

        self.img_list = self.get_filenames(self.img_path)
        if self.split == 'train':
            self.mask_list = self.get_filenames(self.mask_path)
        else:
            self.mask_list = None

    def __len__(self):
        '''
        Returns the number of images in the selected split
        '''
        return len(self.img_list)

    def __getitem__(self, idx):
        '''
        Returns `(img, mask)` for the train split and `img` alone otherwise
        '''
        # Context managers close the underlying file as soon as the pixels are resized.
        with Image.open(self.img_list[idx]) as img_file:
            img = np.array(img_file.resize(self.img_size))

        if self.split == 'train':
            with Image.open(self.mask_list[idx]) as mask_file:
                # Nearest-neighbour resampling: interpolating label ids would invent
                # classes that are not present in the annotation.
                mask = mask_file.convert('L').resize(self.img_size, resample=Image.NEAREST)
            mask = self.encode_segmap(np.array(mask))

        if self.transform:
            img = self.transform(img)

        if self.split == 'train':
            return img, mask
        return img

    def encode_segmap(self, mask):
        '''
        Sets void classes to `ignore_index` and valid classes to their training id.

        The mask is modified in place and also returned. Pixels whose value is in
        neither label list are left unchanged.
        '''
        # Match against a snapshot of the raw labels so that a value written by one
        # remapping step can never be picked up and remapped again by a later step.
        raw = mask.copy()
        for void_label in self.void_labels:
            mask[raw == void_label] = self.ignore_index
        for valid_label in self.valid_labels:
            mask[raw == valid_label] = self.class_map[valid_label]
        return mask

    def get_filenames(self, path):
        '''
        Returns a sorted list of paths to the files inside given `path`
        '''
        # os.listdir gives no ordering guarantee; sorting makes the order deterministic
        # and keeps image and mask lists index-aligned.
        return sorted(os.path.join(path, filename) for filename in os.listdir(path))
111+
112+
113+
class SegModel(pl.LightningModule):
    '''
    Semantic Segmentation Module

    This is a basic semantic segmentation module implemented with Lightning.
    It uses CrossEntropyLoss as the default loss function. May be replaced with
    other loss functions as required.
    It is specific to KITTI dataset i.e. dataloaders are for KITTI
    and Normalize transform uses the mean and standard deviation of this dataset.
    It uses the FCN ResNet50 model as an example.

    Adam optimizer is used along with Cosine Annealing learning rate scheduler.

    Args:
        hparams: namespace providing `root` (dataset path), `batch_size` and `lr`.
    '''
    def __init__(self, hparams):
        super(SegModel, self).__init__()
        self.root_path = hparams.root
        self.batch_size = hparams.batch_size
        self.learning_rate = hparams.lr
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                                 std=[0.32064945, 0.32098866, 0.32325324])
        ])
        self.trainset = KITTI(self.root_path, split='train', transform=self.transform)
        self.testset = KITTI(self.root_path, split='test', transform=self.transform)
        # Take the class count and the ignored label from the dataset instead of
        # repeating them as literals, so the model cannot drift from the mask encoding.
        self.ignore_index = self.trainset.ignore_index
        self.net = fcn_resnet50(pretrained=False,
                                progress=True,
                                num_classes=len(self.trainset.valid_labels))

    def forward(self, x):
        '''
        Runs the wrapped segmentation network on `x` and returns its output
        '''
        return self.net(x)

    def training_step(self, batch, batch_nb):
        '''
        Computes the cross entropy loss for one `(img, mask)` batch
        '''
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        # The network output is indexed with 'out' to obtain the logits for the loss.
        loss_val = F.cross_entropy(out['out'], mask, ignore_index=self.ignore_index)
        return {'loss': loss_val}

    def configure_optimizers(self):
        '''
        Returns the Adam optimizer together with a cosine annealing scheduler
        '''
        opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def train_dataloader(self):
        '''
        Returns a shuffled loader over the training split
        '''
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        '''
        Returns an ordered loader over the test split
        '''
        return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)
163+
164+
165+
def main(hparams):
    '''
    Builds the segmentation model and a Lightning trainer, then starts training.

    `hparams` is handed to `SegModel` unchanged; its `gpus` attribute configures
    the trainer.
    '''
    # Trainer configuration is collected first so it is easy to extend.
    trainer_options = {'gpus': hparams.gpus}

    # Model first, trainer second: same construction order as before.
    segmentation_model = SegModel(hparams)
    lightning_trainer = pl.Trainer(**trainer_options)

    # Kick off the training loop.
    lightning_trainer.fit(segmentation_model)
182+
183+
184+
if __name__ == '__main__':
    # Command line entry point: parse the hyperparameters and start training.
    parser = ArgumentParser()
    # The dataset root is mandatory. Without it `root` would be None and the run
    # would fail later with a TypeError while joining dataset paths, so argparse
    # reports the missing option up front instead.
    parser.add_argument("--root", type=str, required=True,
                        help="path where dataset is stored")
    parser.add_argument("--gpus", type=int, help="number of available GPUs")
    parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")

    hparams = parser.parse_args()

    main(hparams)

0 commit comments

Comments
 (0)