From 1b856cd13b1736242bfe48b0be037ad7b9f9322f Mon Sep 17 00:00:00 2001 From: akshay Date: Sat, 25 Jan 2020 19:29:31 +0530 Subject: [PATCH 1/7] added initial semantic segmentation example --- .../semantic_segmentation/semseg.py | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 pl_examples/full_examples/semantic_segmentation/semseg.py diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py new file mode 100644 index 0000000000000..4749e1e4df84f --- /dev/null +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -0,0 +1,152 @@ +import os +from argparse import ArgumentParser +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader, Dataset +from torchvision.models.segmentation import fcn_resnet50 + +from PIL import Image +import pytorch_lightning as pl + +class KITTI(Dataset): + def __init__(self, root_path, split = 'test', img_size = (1242, 376), transform = None): + self.img_size = img_size + self.void_labels = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] + self.valid_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] + self.ignore_index = 250 + self.class_map = dict(zip(self.valid_labels, range(19))) + self.split = split + self.root = root_path + if self.split == 'train': + self.img_path = os.path.join(self.root, 'training/image_2') + self.mask_path = os.path.join(self.root, 'training/semantic') + else : + self.img_path = os.path.join(self.root, 'testing/image_2') + self.mask_path = None + + self.transform = transform + + self.img_list = self.get_filenames(self.img_path) + if self.split == 'train': + self.mask_list = self.get_filenames(self.mask_path) + else : + self.mask_list = None + + def __len__(self): + return(len(self.img_list)) + + def __getitem__(self, idx): + img = Image.open(self.img_list[idx]) + img = img.resize(self.img_size) + img = np.array(img) + + if self.split == 'train' : + mask = Image.open(self.mask_list[idx]).convert('L') + mask = mask.resize(self.img_size) + mask = np.array(mask) + mask = self.encode_segmap(mask) + + if self.transform : + img = self.transform(img) + + if self.split == 'train' : + return img, mask + else : + return img + + def encode_segmap(self, mask): + ''' + Sets void classes to zero so they won't be considered for training + ''' + for voidc in self.void_labels : + mask[mask == voidc] = self.ignore_index + for validc in self.valid_labels : + mask[mask == validc] = self.class_map[validc] + return mask + + def get_filenames(self, path): + files_list = list() + for filename in os.listdir(path): + files_list.append(os.path.join(path, filename)) + return files_list + +class SegModel(pl.LightningModule): + def __init__(self, hparams): + super(SegModel, self).__init__() + self.hparams = hparams +# self.root_path = '/home/akshay/Projects/pl-sem-seg/' + self.root_path = hparams.root + self.batch_size = hparams.batch_size + self.learning_rate = hparams.lr + self.net = torchvision.models.segmentation.fcn_resnet50(pretrained = False, progress = True, num_classes = 19) + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean = [0.35675976, 0.37380189, 0.3764753], std = [0.32064945, 0.32098866, 0.32325324]) + ]) + self.trainset = KITTI(self.root_path, split = 'train', transform = self.transform) + self.testset = KITTI(self.root_path, split = 'test', transform = self.transform) + + def forward(self, x): + return self.net(x) + + def training_step(self, batch, batch_nb) : + img, mask = batch + img = img.float() + mask = mask.long() + out = self.forward(img) + loss_val = F.cross_entropy(out['out'], mask, ignore_index = 250) + return {'loss' : loss_val} + +# def test_step(self, batch, batch_nb): +# print('-----------------testing-----------------') +# img = batch +# # log sampled images +# masks = self.net(img) +# grid = torchvision.utils.make_grid(masks) +# self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch) + + def configure_optimizers(self): + opt = torch.optim.Adam(self.net.parameters(), lr = self.learning_rate) + sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max = 10) + return [opt], [sch] + + def train_dataloader(self): + return DataLoader(self.trainset, batch_size = self.batch_size, shuffle = True) + + def test_dataloader(self): + return DataLoader(self.testset, batch_size = 1, shuffle = True) + +def main(hparams): + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ + model = SegModel(hparams) + + # ------------------------ + # 2 INIT TRAINER + # ------------------------ + trainer = pl.Trainer( + gpus = hparams.gpus + ) + + # ------------------------ + # 3 START TRAINING + # ------------------------ + trainer.fit(model) + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument("--root", type = str, help = "path where dataset is stored") + parser.add_argument("--gpus", type = int, help = "number of available GPUs") + parser.add_argument("--batch_size", type = int, default = 4, help = "size of the batches") + parser.add_argument("--lr", type = float, default = 0.001, help = "adam: learning rate") + + hparams = parser.parse_args() + + main(hparams) From c2bd71032b19303731fe0f63373ca88db29c193e Mon Sep 17 00:00:00 2001 From: akshay Date: Sat, 25 Jan 2020 19:52:03 +0530 Subject: [PATCH 2/7] removed unnecessary lines. --- pl_examples/full_examples/semantic_segmentation/semseg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py index 4749e1e4df84f..592e0192babbd 100644 --- a/pl_examples/full_examples/semantic_segmentation/semseg.py +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -80,7 +80,6 @@ class SegModel(pl.LightningModule): def __init__(self, hparams): super(SegModel, self).__init__() self.hparams = hparams -# self.root_path = '/home/akshay/Projects/pl-sem-seg/' self.root_path = hparams.root self.batch_size = hparams.batch_size self.learning_rate = hparams.lr From a7889d46088374222f60368fcade4732a95f6ff1 Mon Sep 17 00:00:00 2001 From: akshay Date: Fri, 14 Feb 2020 22:40:12 +0530 Subject: [PATCH 3/7] changed according to reviews --- .../semantic_segmentation/semseg.py | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py index 592e0192babbd..83ec2d253c367 100644 --- a/pl_examples/full_examples/semantic_segmentation/semseg.py +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -14,8 +14,9 @@ from PIL import Image import pytorch_lightning as pl + class KITTI(Dataset): - def __init__(self, root_path, split = 'test', img_size = (1242, 376), transform = None): + def __init__(self, root_path, split='test', img_size=(1242, 376), transform=None): self.img_size = img_size self.void_labels = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] self.valid_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] @@ -24,103 +25,99 @@ def __init__(self, root_path, split = 'test', img_size = (1242, 376), transform self.split = split self.root = root_path if self.split == 'train': - self.img_path = os.path.join(self.root, 'training/image_2') + self.img_path = os.path.join(self.root, 'training/image_2') self.mask_path = os.path.join(self.root, 'training/semantic') - else : + else: self.img_path = os.path.join(self.root, 'testing/image_2') self.mask_path = None self.transform = transform - + self.img_list = self.get_filenames(self.img_path) if self.split == 'train': self.mask_list = self.get_filenames(self.mask_path) - else : + else: self.mask_list = None - + def __len__(self): return(len(self.img_list)) - + def __getitem__(self, idx): img = Image.open(self.img_list[idx]) img = img.resize(self.img_size) img = np.array(img) - - if self.split == 'train' : + + if self.split == 'train': mask = Image.open(self.mask_list[idx]).convert('L') mask = mask.resize(self.img_size) mask = np.array(mask) mask = self.encode_segmap(mask) - - if self.transform : + + if self.transform: img = self.transform(img) - - if self.split == 'train' : + + if self.split == 'train': return img, mask - else : + else: return img - + def encode_segmap(self, mask): ''' Sets void classes to zero so they won't be considered for training ''' - for voidc in self.void_labels : + for voidc in self.void_labels: mask[mask == voidc] = self.ignore_index - for validc in self.valid_labels : + for validc in self.valid_labels: mask[mask == validc] = self.class_map[validc] return mask - + def get_filenames(self, path): + ''' + Returns a list of absolute paths to images inside given `path` + ''' files_list = list() for filename in os.listdir(path): files_list.append(os.path.join(path, filename)) return files_list - + + class SegModel(pl.LightningModule): def __init__(self, hparams): super(SegModel, self).__init__() - self.hparams = hparams self.root_path = hparams.root self.batch_size = hparams.batch_size self.learning_rate = hparams.lr - self.net = torchvision.models.segmentation.fcn_resnet50(pretrained = False, progress = True, num_classes = 19) + self.net = torchvision.models.segmentation.fcn_resnet50(pretrained=False, progress=True, num_classes=19) self.transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize(mean = [0.35675976, 0.37380189, 0.3764753], std = [0.32064945, 0.32098866, 0.32325324]) + transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) ]) - self.trainset = KITTI(self.root_path, split = 'train', transform = self.transform) - self.testset = KITTI(self.root_path, split = 'test', transform = self.transform) - + self.trainset = KITTI(self.root_path, split='train', transform=self.transform) + self.testset = KITTI(self.root_path, split='test', transform=self.transform) + def forward(self, x): return self.net(x) - - def training_step(self, batch, batch_nb) : + + def training_step(self, batch, batch_nb): img, mask = batch img = img.float() mask = mask.long() out = self.forward(img) - loss_val = F.cross_entropy(out['out'], mask, ignore_index = 250) - return {'loss' : loss_val} - -# def test_step(self, batch, batch_nb): -# print('-----------------testing-----------------') -# img = batch -# # log sampled images -# masks = self.net(img) -# grid = torchvision.utils.make_grid(masks) -# self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch) - + loss_val = F.cross_entropy(out['out'], mask, ignore_index=250) + return {'loss': loss_val} + def configure_optimizers(self): - opt = torch.optim.Adam(self.net.parameters(), lr = self.learning_rate) - sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max = 10) + opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate) + sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10) return [opt], [sch] - + def train_dataloader(self): - return DataLoader(self.trainset, batch_size = self.batch_size, shuffle = True) - + return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True) + def test_dataloader(self): - return DataLoader(self.testset, batch_size = 1, shuffle = True) - + return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False) + + def main(hparams): # ------------------------ # 1 INIT LIGHTNING MODEL @@ -131,20 +128,23 @@ def main(hparams): # 2 INIT TRAINER # ------------------------ trainer = pl.Trainer( - gpus = hparams.gpus + gpus=hparams.gpus ) # ------------------------ # 3 START TRAINING # ------------------------ - trainer.fit(model) + # trainer.fit(model) + + trainer.test(model) + if __name__ == '__main__': parser = ArgumentParser() - parser.add_argument("--root", type = str, help = "path where dataset is stored") - parser.add_argument("--gpus", type = int, help = "number of available GPUs") - parser.add_argument("--batch_size", type = int, default = 4, help = "size of the batches") - parser.add_argument("--lr", type = float, default = 0.001, help = "adam: learning rate") + parser.add_argument("--root", type=str, help="path where dataset is stored") + parser.add_argument("--gpus", type=int, help="number of available GPUs") + parser.add_argument("--batch_size", type=int, default=4, help="size of the batches") + parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") hparams = parser.parse_args() From a2ada020cde868ecfdc85aed07d4d279b49199b4 Mon Sep 17 00:00:00 2001 From: akshay Date: Fri, 14 Feb 2020 22:42:27 +0530 Subject: [PATCH 4/7] minor changes --- pl_examples/full_examples/semantic_segmentation/semseg.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py index 83ec2d253c367..965a1e6a2bf32 100644 --- a/pl_examples/full_examples/semantic_segmentation/semseg.py +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -134,9 +134,7 @@ def main(hparams): # ------------------------ # 3 START TRAINING # ------------------------ - # trainer.fit(model) - - trainer.test(model) + trainer.fit(model) if __name__ == '__main__': From aec1d02ff01aecce0ff1f2ad0c0d390436e48045 Mon Sep 17 00:00:00 2001 From: akshay Date: Sun, 16 Feb 2020 12:30:35 +0530 Subject: [PATCH 5/7] Added some documentation for Dataset class --- .../semantic_segmentation/semseg.py | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py index 965a1e6a2bf32..8eebf81f46e0c 100644 --- a/pl_examples/full_examples/semantic_segmentation/semseg.py +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -1,6 +1,7 @@ import os from argparse import ArgumentParser from collections import OrderedDict +from PIL import Image import numpy as np import torch @@ -11,17 +12,45 @@ from torch.utils.data import DataLoader, Dataset from torchvision.models.segmentation import fcn_resnet50 -from PIL import Image import pytorch_lightning as pl class KITTI(Dataset): - def __init__(self, root_path, split='test', img_size=(1242, 376), transform=None): + ''' + Dataset Class for KITTI Semantic Segmentation Benchmark dataset + Dataset link - http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 + + There are 34 classes in the given labels. However, not all of them are useful for training + (like railings on highways, road dividers, etc.). + So, these useless classes (the pixel values of these classes) are stored in the `void_labels`. + The useful classes are stored in the `valid_labels`. + + The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index` + (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and + `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by + the loss function when comparing with the output. + + The `get_filenames` function retrieves the filenames of all images in the given `path` and + saves the absolute path in a list. + + In the `get_item` function, images and masks are resized to the given `img_size`, masks are + encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only + (mask does not usually require transforms, but they can be implemented in a similar way). + ''' + def __init__( + self, + root_path, + split='test', + img_size=(1242, 376), + void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1], + valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33], + transform=None + ): self.img_size = img_size - self.void_labels = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] - self.valid_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] + self.void_labels = void_labels + self.valid_labels = valid_labels self.ignore_index = 250 - self.class_map = dict(zip(self.valid_labels, range(19))) + self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels)))) self.split = split self.root = root_path if self.split == 'train': @@ -128,7 +157,9 @@ def main(hparams): # 2 INIT TRAINER # ------------------------ trainer = pl.Trainer( - gpus=hparams.gpus + gpus=hparams.gpus, + max_nb_epochs=5, + early_stop_callback=None ) # ------------------------ From 7adfdff25b205435e79b9e5fa77b5ace82a577a0 Mon Sep 17 00:00:00 2001 From: akshay Date: Sun, 16 Feb 2020 12:31:53 +0530 Subject: [PATCH 6/7] Fixed some long lines --- pl_examples/full_examples/semantic_segmentation/semseg.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py index 8eebf81f46e0c..89331cf397cdd 100644 --- a/pl_examples/full_examples/semantic_segmentation/semseg.py +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -116,10 +116,13 @@ def __init__(self, hparams): self.root_path = hparams.root self.batch_size = hparams.batch_size self.learning_rate = hparams.lr - self.net = torchvision.models.segmentation.fcn_resnet50(pretrained=False, progress=True, num_classes=19) + self.net = torchvision.models.segmentation.fcn_resnet50(pretrained=False, + progress=True, + num_classes=19) self.transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) + transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], + std=[0.32064945, 0.32098866, 0.32325324]) ]) self.trainset = KITTI(self.root_path, split='train', transform=self.transform) self.testset = KITTI(self.root_path, split='test', transform=self.transform) From e615f6d286b8b6e43b02202ade115feda3ff70d2 Mon Sep 17 00:00:00 2001 From: akshay Date: Sun, 16 Feb 2020 18:22:22 +0530 Subject: [PATCH 7/7] added docstring for LightningModule --- .../semantic_segmentation/semseg.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pl_examples/full_examples/semantic_segmentation/semseg.py b/pl_examples/full_examples/semantic_segmentation/semseg.py index 89331cf397cdd..c763a39f1f64c 100644 --- a/pl_examples/full_examples/semantic_segmentation/semseg.py +++ b/pl_examples/full_examples/semantic_segmentation/semseg.py @@ -111,6 +111,18 @@ def get_filenames(self, path): class SegModel(pl.LightningModule): + ''' + Semantic Segmentation Module + + This is a basic semantic segmentation module implemented with Lightning. + It uses CrossEntropyLoss as the default loss function. May be replaced with + other loss functions as required. + It is specific to KITTI dataset i.e. dataloaders are for KITTI + and Normalize transform uses the mean and standard deviation of this dataset. + It uses the FCN ResNet50 model as an example. + + Adam optimizer is used along with Cosine Annealing learning rate scheduler. + ''' def __init__(self, hparams): super(SegModel, self).__init__() self.root_path = hparams.root @@ -160,9 +172,7 @@ def main(hparams): # 2 INIT TRAINER # ------------------------ trainer = pl.Trainer( - gpus=hparams.gpus, - max_nb_epochs=5, - early_stop_callback=None + gpus=hparams.gpus ) # ------------------------