Skip to content

Commit 4c2a135

Browse files
borisdayma authored and Borda committed
feat(semseg): allow model customization (Lightning-AI#1371)
* feat(semantic_segmentation): allow customization of unet * feat(semseg): allow model customization * style(semseg): format to PEP8 * fix(semseg): rename logger * docs(changelog): updated semantic segmentation example * suggestions * suggestions * flake8 Co-authored-by: J. Borovec <[email protected]>
1 parent a2027f7 commit 4c2a135

File tree

3 files changed

+121
-67
lines changed

3 files changed

+121
-67
lines changed

CHANGELOG.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1717

1818
- Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
1919

20+
- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
21+
2022
-
2123

2224
### Deprecated
2325

2426
-
2527

26-
-
27-
2828

2929
### Removed
3030

pl_examples/domain_templates/semantic_segmentation.py

+85-38
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77
import torchvision.transforms as transforms
88
from PIL import Image
99
from torch.utils.data import DataLoader, Dataset
10+
import random
1011

1112
import pytorch_lightning as pl
1213
from pl_examples.models.unet import UNet
14+
from pytorch_lightning.loggers import WandbLogger
15+
16+
DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
17+
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
1318

1419

1520
class KITTI(Dataset):
@@ -34,37 +39,40 @@ class KITTI(Dataset):
3439
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
3540
(mask does not usually require transforms, but they can be implemented in a similar way).
3641
"""
42+
IMAGE_PATH = os.path.join('training', 'image_2')
43+
MASK_PATH = os.path.join('training', 'semantic')
3744

3845
def __init__(
3946
self,
40-
root_path,
41-
split='test',
42-
img_size=(1242, 376),
43-
void_labels=[0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1],
44-
valid_labels=[7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33],
47+
data_path: str,
48+
split: str,
49+
img_size: tuple = (1242, 376),
50+
void_labels: list = DEFAULT_VOID_LABELS,
51+
valid_labels: list = DEFAULT_VALID_LABELS,
4552
transform=None
4653
):
4754
self.img_size = img_size
4855
self.void_labels = void_labels
4956
self.valid_labels = valid_labels
5057
self.ignore_index = 250
5158
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
52-
self.split = split
53-
self.root = root_path
54-
if self.split == 'train':
55-
self.img_path = os.path.join(self.root, 'training/image_2')
56-
self.mask_path = os.path.join(self.root, 'training/semantic')
57-
else:
58-
self.img_path = os.path.join(self.root, 'testing/image_2')
59-
self.mask_path = None
60-
6159
self.transform = transform
6260

61+
self.split = split
62+
self.data_path = data_path
63+
self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
64+
self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
6365
self.img_list = self.get_filenames(self.img_path)
66+
self.mask_list = self.get_filenames(self.mask_path)
67+
68+
# Split between train and valid set (80/20)
69+
random_inst = random.Random(12345) # for repeatability
70+
n_items = len(self.img_list)
71+
idxs = random_inst.sample(range(n_items), n_items // 5)
6472
if self.split == 'train':
65-
self.mask_list = self.get_filenames(self.mask_path)
66-
else:
67-
self.mask_list = None
73+
idxs = [idx for idx in range(n_items) if idx not in idxs]
74+
self.img_list = [self.img_list[i] for i in idxs]
75+
self.mask_list = [self.mask_list[i] for i in idxs]
6876

6977
def __len__(self):
7078
return len(self.img_list)
@@ -74,19 +82,15 @@ def __getitem__(self, idx):
7482
img = img.resize(self.img_size)
7583
img = np.array(img)
7684

77-
if self.split == 'train':
78-
mask = Image.open(self.mask_list[idx]).convert('L')
79-
mask = mask.resize(self.img_size)
80-
mask = np.array(mask)
81-
mask = self.encode_segmap(mask)
85+
mask = Image.open(self.mask_list[idx]).convert('L')
86+
mask = mask.resize(self.img_size)
87+
mask = np.array(mask)
88+
mask = self.encode_segmap(mask)
8289

8390
if self.transform:
8491
img = self.transform(img)
8592

86-
if self.split == 'train':
87-
return img, mask
88-
else:
89-
return img
93+
return img, mask
9094

9195
def encode_segmap(self, mask):
9296
"""
@@ -96,6 +100,8 @@ def encode_segmap(self, mask):
96100
mask[mask == voidc] = self.ignore_index
97101
for validc in self.valid_labels:
98102
mask[mask == validc] = self.class_map[validc]
103+
# remove extra idxs from updated dataset
104+
mask[mask > 18] = self.ignore_index
99105
return mask
100106

101107
def get_filenames(self, path):
@@ -124,17 +130,19 @@ class SegModel(pl.LightningModule):
124130

125131
def __init__(self, hparams):
126132
super().__init__()
127-
self.root_path = hparams.root
133+
self.hparams = hparams
134+
self.data_path = hparams.data_path
128135
self.batch_size = hparams.batch_size
129136
self.learning_rate = hparams.lr
130-
self.net = UNet(num_classes=19)
137+
self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
138+
features_start=hparams.features_start, bilinear=hparams.bilinear)
131139
self.transform = transforms.Compose([
132140
transforms.ToTensor(),
133141
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
134142
std=[0.32064945, 0.32098866, 0.32325324])
135143
])
136-
self.trainset = KITTI(self.root_path, split='train', transform=self.transform)
137-
self.testset = KITTI(self.root_path, split='test', transform=self.transform)
144+
self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
145+
self.validset = KITTI(self.data_path, split='valid', transform=self.transform)
138146

139147
def forward(self, x):
140148
return self.net(x)
@@ -145,7 +153,21 @@ def training_step(self, batch, batch_nb):
145153
mask = mask.long()
146154
out = self(img)
147155
loss_val = F.cross_entropy(out, mask, ignore_index=250)
148-
return {'loss': loss_val}
156+
log_dict = {'train_loss': loss_val}
157+
return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}
158+
159+
def validation_step(self, batch, batch_idx):
160+
img, mask = batch
161+
img = img.float()
162+
mask = mask.long()
163+
out = self(img)
164+
loss_val = F.cross_entropy(out, mask, ignore_index=250)
165+
return {'val_loss': loss_val}
166+
167+
def validation_epoch_end(self, outputs):
168+
loss_val = sum(output['val_loss'] for output in outputs) / len(outputs)
169+
log_dict = {'val_loss': loss_val}
170+
return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}
149171

150172
def configure_optimizers(self):
151173
opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)
@@ -155,8 +177,8 @@ def configure_optimizers(self):
155177
def train_dataloader(self):
156178
return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)
157179

158-
def test_dataloader(self):
159-
return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False)
180+
def val_dataloader(self):
181+
return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)
160182

161183

162184
def main(hparams):
@@ -166,24 +188,49 @@ def main(hparams):
166188
model = SegModel(hparams)
167189

168190
# ------------------------
169-
# 2 INIT TRAINER
191+
# 2 SET LOGGER
192+
# ------------------------
193+
logger = False
194+
if hparams.log_wandb:
195+
logger = WandbLogger()
196+
197+
# optional: log model topology
198+
logger.watch(model.net)
199+
200+
# ------------------------
201+
# 3 INIT TRAINER
170202
# ------------------------
171203
trainer = pl.Trainer(
172-
gpus=hparams.gpus
204+
gpus=hparams.gpus,
205+
logger=logger,
206+
max_epochs=hparams.epochs,
207+
accumulate_grad_batches=hparams.grad_batches,
208+
distributed_backend=hparams.distributed_backend,
209+
precision=16 if hparams.use_amp else 32,
173210
)
174211

175212
# ------------------------
176-
# 3 START TRAINING
213+
# 5 START TRAINING
177214
# ------------------------
178215
trainer.fit(model)
179216

180217

181218
if __name__ == '__main__':
182219
parser = ArgumentParser()
183-
parser.add_argument("--root", type=str, help="path where dataset is stored")
184-
parser.add_argument("--gpus", type=int, help="number of available GPUs")
220+
parser.add_argument("--data_path", type=str, help="path where dataset is stored")
221+
parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs")
222+
parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'),
223+
help='supports three options dp, ddp, ddp2')
224+
parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision')
185225
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
186226
parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
227+
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
228+
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
229+
parser.add_argument("--bilinear", action='store_true', default=False,
230+
help="whether to use bilinear interpolation or transposed")
231+
parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate")
232+
parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train")
233+
parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases")
187234

188235
hparams = parser.parse_args()
189236

pl_examples/models/unet.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,46 @@ class UNet(nn.Module):
99
Link - https://arxiv.org/abs/1505.04597
1010
1111
Parameters:
12-
num_classes (int): Number of output classes required (default 19 for KITTI dataset)
13-
bilinear (bool): Whether to use bilinear interpolation or transposed
12+
num_classes: Number of output classes required (default 19 for KITTI dataset)
13+
num_layers: Number of layers in each side of U-net
14+
features_start: Number of features in first layer
15+
bilinear: Whether to use bilinear interpolation or transposed
1416
convolutions for upsampling.
1517
"""
1618

17-
def __init__(self, num_classes=19, bilinear=False):
19+
def __init__(
20+
self, num_classes: int = 19,
21+
num_layers: int = 5,
22+
features_start: int = 64,
23+
bilinear: bool = False
24+
):
1825
super().__init__()
19-
self.layer1 = DoubleConv(3, 64)
20-
self.layer2 = Down(64, 128)
21-
self.layer3 = Down(128, 256)
22-
self.layer4 = Down(256, 512)
23-
self.layer5 = Down(512, 1024)
26+
self.num_layers = num_layers
2427

25-
self.layer6 = Up(1024, 512, bilinear=bilinear)
26-
self.layer7 = Up(512, 256, bilinear=bilinear)
27-
self.layer8 = Up(256, 128, bilinear=bilinear)
28-
self.layer9 = Up(128, 64, bilinear=bilinear)
28+
layers = [DoubleConv(3, features_start)]
2929

30-
self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1)
30+
feats = features_start
31+
for _ in range(num_layers - 1):
32+
layers.append(Down(feats, feats * 2))
33+
feats *= 2
3134

32-
def forward(self, x):
33-
x1 = self.layer1(x)
34-
x2 = self.layer2(x1)
35-
x3 = self.layer3(x2)
36-
x4 = self.layer4(x3)
37-
x5 = self.layer5(x4)
35+
for _ in range(num_layers - 1):
36+
layers.append(Up(feats, feats // 2), bilinear)
37+
feats //= 2
38+
39+
layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))
3840

39-
x6 = self.layer6(x5, x4)
40-
x6 = self.layer7(x6, x3)
41-
x6 = self.layer8(x6, x2)
42-
x6 = self.layer9(x6, x1)
41+
self.layers = nn.ModuleList(layers)
4342

44-
return self.layer10(x6)
43+
def forward(self, x):
44+
xi = [self.layers[0](x)]
45+
# Down path
46+
for layer in self.layers[1:self.num_layers]:
47+
xi.append(layer(xi[-1]))
48+
# Up path
49+
for i, layer in enumerate(self.layers[self.num_layers:-1]):
50+
xi[-1] = layer(xi[-1], xi[-2 - i])
51+
return self.layers[-1](xi[-1])
4552

4653

4754
class DoubleConv(nn.Module):
@@ -50,7 +57,7 @@ class DoubleConv(nn.Module):
5057
(3x3 conv -> BN -> ReLU) ** 2
5158
"""
5259

53-
def __init__(self, in_ch, out_ch):
60+
def __init__(self, in_ch: int, out_ch: int):
5461
super().__init__()
5562
self.net = nn.Sequential(
5663
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
@@ -70,7 +77,7 @@ class Down(nn.Module):
7077
Combination of MaxPool2d and DoubleConv in series
7178
"""
7279

73-
def __init__(self, in_ch, out_ch):
80+
def __init__(self, in_ch: int, out_ch: int):
7481
super().__init__()
7582
self.net = nn.Sequential(
7683
nn.MaxPool2d(kernel_size=2, stride=2),
@@ -88,7 +95,7 @@ class Up(nn.Module):
8895
followed by double 3x3 convolution.
8996
"""
9097

91-
def __init__(self, in_ch, out_ch, bilinear=False):
98+
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
9299
super().__init__()
93100
self.upsample = None
94101
if bilinear:

0 commit comments

Comments
 (0)