Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(semseg): allow model customization #1338

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn

from models.unet.parts import DoubleConv, Down, Up
from .parts import DoubleConv, Down, Up


class UNet(nn.Module):
    '''
    Architecture based on U-Net: Convolutional Networks for Biomedical
    Image Segmentation (https://arxiv.org/abs/1505.04597).

    Parameters:
        num_classes (int) - Number of output classes required (default 19 for KITTI dataset)
        num_layers (int) - Number of layers in each side of U-net
        features_start (int) - Number of features in first layer
        bilinear (bool) - Whether to use bilinear interpolation or transposed
            convolutions for upsampling.
    '''

    def __init__(self, num_classes=19, num_layers=5, features_start=64, bilinear=False):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers = {num_layers}, expected: num_layers > 0')
        self.num_layers = num_layers

        layers = [DoubleConv(3, features_start)]

        # Contracting path: each Down block doubles the feature count.
        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        # Expanding path: each Up block halves the feature count.
        # BUG FIX: the `bilinear` flag was accepted by __init__ but never
        # forwarded to Up, so the user's choice of upsampling mode was
        # silently ignored (the pre-refactor code passed it explicitly).
        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear=bilinear))
            feats //= 2

        # 1x1 convolution maps the final feature maps to per-pixel class scores.
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        xi = [self.layers[0](x)]
        # Down path: keep every intermediate activation for skip connections.
        for layer in self.layers[1:self.num_layers]:
            xi.append(layer(xi[-1]))
        # Up path: each Up consumes the running activation plus the matching
        # down-path activation (xi[-2 - i]).
        for i, layer in enumerate(self.layers[self.num_layers:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        # Final 1x1 conv produces the segmentation logits.
        return self.layers[-1](xi[-1])
113 changes: 78 additions & 35 deletions pl_examples/full_examples/semantic_segmentation/semseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from PIL import Image
from models.unet.model import UNet
from torch.utils.data import DataLoader, Dataset
import random

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger


class KITTI(Dataset):
Expand Down Expand Up @@ -37,8 +39,8 @@ class KITTI(Dataset):

def __init__(
self,
root_path,
split='test',
data_path,
split,
img_size=(1242, 376),
void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1],
valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33],
Expand All @@ -49,22 +51,23 @@ def __init__(
self.valid_labels = valid_labels
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.split = split
self.root = root_path
if self.split == 'train':
self.img_path = os.path.join(self.root, 'training/image_2')
self.mask_path = os.path.join(self.root, 'training/semantic')
else:
self.img_path = os.path.join(self.root, 'testing/image_2')
self.mask_path = None

self.transform = transform

self.split = split
self.data_path = data_path
self.img_path = os.path.join(self.data_path, 'training/image_2')
self.mask_path = os.path.join(self.data_path, 'training/semantic')
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)

# Split between train and valid set (80/20)
random_inst = random.Random(12345) # for repeatability
n_items = len(self.img_list)
idxs = random_inst.sample(range(n_items), n_items // 5)
if self.split == 'train':
self.mask_list = self.get_filenames(self.mask_path)
else:
self.mask_list = None
idxs = [idx for idx in range(n_items) if idx not in idxs]
self.img_list = [self.img_list[i] for i in idxs]
self.mask_list = [self.mask_list[i] for i in idxs]

def __len__(self):
    """Number of samples (images) in this split."""
    return len(self.img_list)
Expand All @@ -74,19 +77,15 @@ def __getitem__(self, idx):
img = img.resize(self.img_size)
img = np.array(img)

if self.split == 'train':
mask = Image.open(self.mask_list[idx]).convert('L')
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)
mask = Image.open(self.mask_list[idx]).convert('L')
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)

if self.transform:
img = self.transform(img)

if self.split == 'train':
return img, mask
else:
return img
return img, mask

def encode_segmap(self, mask):
    '''
    Encode a raw KITTI semantic mask into contiguous train ids, in place.

    Void classes are collapsed to `self.ignore_index`; valid classes are
    remapped to [0, len(valid_labels)) via `self.class_map`. Mutates and
    returns `mask`.
    '''
    # NOTE(review): the loop header below was garbled in the scraped diff;
    # reconstructed from the symmetric valid-label loop.
    for voidc in self.void_labels:
        mask[mask == voidc] = self.ignore_index
    for validc in self.valid_labels:
        mask[mask == validc] = self.class_map[validc]
    # Updated dataset revisions contain label ids beyond the 19 train
    # classes; fold anything still unmapped into the ignore index too.
    mask[mask > 18] = self.ignore_index
    return mask

def get_filenames(self, path):
Expand Down Expand Up @@ -124,17 +125,19 @@ class SegModel(pl.LightningModule):

def __init__(self, hparams):
    """Build the U-Net, input transforms and KITTI datasets from parsed CLI hparams.

    NOTE(review): the scraped diff interleaved the pre- and post-change
    versions of this method; this is the reconstructed post-change version.
    """
    super().__init__()
    self.hparams = hparams
    self.data_path = hparams.data_path
    self.batch_size = hparams.batch_size
    self.learning_rate = hparams.lr
    self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
                    features_start=hparams.features_start, bilinear=hparams.bilinear)
    # Normalization constants: per-channel mean/std of the KITTI training images.
    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                             std=[0.32064945, 0.32098866, 0.32325324])
    ])
    self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
    self.validset = KITTI(self.data_path, split='valid', transform=self.transform)

def forward(self, x):
    """Run the input batch through the wrapped U-Net and return its logits."""
    out = self.net(x)
    return out
Expand All @@ -145,7 +148,21 @@ def training_step(self, batch, batch_nb):
mask = mask.long()
out = self(img)
loss_val = F.cross_entropy(out, mask, ignore_index=250)
return {'loss': loss_val}
log_dict = {'train_loss': loss_val}
return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}

def validation_step(self, batch, batch_idx):
    """Compute the cross-entropy loss for one validation batch."""
    img, mask = batch
    logits = self(img.float())
    # 250 marks void/unlabeled pixels; they are excluded from the loss.
    loss = F.cross_entropy(logits, mask.long(), ignore_index=250)
    return {'val_loss': loss}

def validation_epoch_end(self, outputs):
    """Average the per-batch validation losses and expose them for logging."""
    losses = [out['val_loss'] for out in outputs]
    avg_loss = sum(losses) / len(losses)
    log_dict = {'val_loss': avg_loss}
    return {'log': log_dict, 'val_loss': avg_loss, 'progress_bar': log_dict}

def configure_optimizers(self):
opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
Expand All @@ -155,8 +172,8 @@ def configure_optimizers(self):
def train_dataloader(self):
    """Shuffled loader over the training split."""
    loader = DataLoader(self.trainset,
                        batch_size=self.batch_size,
                        shuffle=True)
    return loader

def test_dataloader(self):
return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)
def val_dataloader(self):
    """Loader over the held-out validation split; order is kept stable."""
    valid_loader = DataLoader(self.validset,
                              batch_size=self.batch_size,
                              shuffle=False)
    return valid_loader


def main(hparams):
Expand All @@ -166,24 +183,50 @@ def main(hparams):
model = SegModel(hparams)

# ------------------------
# 2 INIT TRAINER
# 2 SET LOGGER
# ------------------------
logger = False
if hparams.log_wandb:
logger = WandbLogger()

# optional: log model topology
logger.watch(model.net)

# ------------------------
# 3 INIT TRAINER
# ------------------------
trainer = pl.Trainer(
gpus=hparams.gpus
gpus=hparams.gpus,
logger=logger,
max_epochs=hparams.epochs,
accumulate_grad_batches=hparams.grad_batches,
distributed_backend=hparams.distributed_backend,
use_amp=hparams.use_16bit
)

# ------------------------
# 3 START TRAINING
# 5 START TRAINING
# ------------------------
trainer.fit(model)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--data-path", type=str, help="path where dataset is stored")
    parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
    parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
                        help='supports three options dp, ddp, ddp2')
    parser.add_argument('--use-16bit', dest='use_16bit', action='store_true',
                        help='if true uses 16 bit precision')
    parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
    parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
    # BUG FIX: features_start is a channel count and must be an int —
    # type=float passed a float into nn.Conv2d construction, which rejects it.
    parser.add_argument("--features_start", type=int, default=64, help="number of features in first layer")
    # BUG FIX: bilinear is a boolean switch; type=float could not parse
    # 'True'/'False' and fed floats to the model. Use a flag, consistent
    # with --use-16bit and --log_wandb below.
    parser.add_argument("--bilinear", action='store_true',
                        help="whether to use bilinear interpolation or transposed")
    parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate")
    parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train")
    parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases")

    hparams = parser.parse_args()

Expand Down