7
7
import torchvision .transforms as transforms
8
8
from PIL import Image
9
9
from torch .utils .data import DataLoader , Dataset
10
+ import random
10
11
11
12
import pytorch_lightning as pl
12
13
from pl_examples .models .unet import UNet
14
+ from pytorch_lightning .loggers import WandbLogger
15
+
16
# Raw KITTI semantic label ids that carry no training signal;
# KITTI.encode_segmap maps these to the ignore index.
DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
# The 19 label ids kept for training, remapped to contiguous indices 0..18.
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
13
18
14
19
15
20
class KITTI (Dataset ):
@@ -34,37 +39,40 @@ class KITTI(Dataset):
34
39
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
35
40
(mask does not usually require transforms, but they can be implemented in a similar way).
36
41
"""
42
+ IMAGE_PATH = os .path .join ('training' , 'image_2' )
43
+ MASK_PATH = os .path .join ('training' , 'semantic' )
37
44
38
45
def __init__(
    self,
    data_path: str,
    split: str,
    img_size: tuple = (1242, 376),
    void_labels: list = DEFAULT_VOID_LABELS,
    valid_labels: list = DEFAULT_VALID_LABELS,
    transform=None
):
    """Index the KITTI semantic-segmentation files and select a split.

    Args:
        data_path: root folder containing ``training/image_2`` and
            ``training/semantic`` (see ``IMAGE_PATH`` / ``MASK_PATH``).
        split: ``'train'`` keeps 80% of the files; any other value
            (``'valid'``) keeps the held-out 20%.
        img_size: target (width, height) the image and mask are resized to
            in ``__getitem__``.
        void_labels: raw label ids that get remapped to ``ignore_index``.
        valid_labels: raw label ids remapped to contiguous class indices.
        transform: optional callable applied to the image only.
    """
    self.img_size = img_size
    self.void_labels = void_labels
    self.valid_labels = valid_labels
    self.ignore_index = 250
    # raw label id -> contiguous train id (0 .. len(valid_labels)-1)
    self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
    self.transform = transform

    self.split = split
    self.data_path = data_path
    self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
    self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
    self.img_list = self.get_filenames(self.img_path)
    self.mask_list = self.get_filenames(self.mask_path)

    # Split between train and valid set (80/20), seeded for repeatability.
    random_inst = random.Random(12345)
    n_items = len(self.img_list)
    valid_idxs = random_inst.sample(range(n_items), n_items // 5)
    if self.split == 'train':
        # Set membership makes the complement O(n) instead of the
        # original O(n^2) list scan per index.
        held_out = set(valid_idxs)
        idxs = [i for i in range(n_items) if i not in held_out]
    else:
        # Keep the sample order exactly as drawn (matches original behavior).
        idxs = valid_idxs
    self.img_list = [self.img_list[i] for i in idxs]
    self.mask_list = [self.mask_list[i] for i in idxs]
68
76
69
77
def __len__(self):
    """Number of image/mask pairs in the selected split."""
    return len(self.img_list)
@@ -74,19 +82,15 @@ def __getitem__(self, idx):
74
82
img = img .resize (self .img_size )
75
83
img = np .array (img )
76
84
77
- if self .split == 'train' :
78
- mask = Image .open (self .mask_list [idx ]).convert ('L' )
79
- mask = mask .resize (self .img_size )
80
- mask = np .array (mask )
81
- mask = self .encode_segmap (mask )
85
+ mask = Image .open (self .mask_list [idx ]).convert ('L' )
86
+ mask = mask .resize (self .img_size )
87
+ mask = np .array (mask )
88
+ mask = self .encode_segmap (mask )
82
89
83
90
if self .transform :
84
91
img = self .transform (img )
85
92
86
- if self .split == 'train' :
87
- return img , mask
88
- else :
89
- return img
93
+ return img , mask
90
94
91
95
def encode_segmap (self , mask ):
92
96
"""
@@ -96,6 +100,8 @@ def encode_segmap(self, mask):
96
100
mask [mask == voidc ] = self .ignore_index
97
101
for validc in self .valid_labels :
98
102
mask [mask == validc ] = self .class_map [validc ]
103
+ # remove extra idxs from updated dataset
104
+ mask [mask > 18 ] = self .ignore_index
99
105
return mask
100
106
101
107
def get_filenames (self , path ):
@@ -124,17 +130,19 @@ class SegModel(pl.LightningModule):
124
130
125
131
def __init__(self, hparams):
    """UNet-based segmentation LightningModule for the KITTI example.

    Args:
        hparams: argparse.Namespace carrying ``data_path``, ``batch_size``,
            ``lr``, ``num_layers``, ``features_start`` and ``bilinear``.
    """
    super().__init__()
    self.hparams = hparams
    self.data_path = hparams.data_path
    self.batch_size = hparams.batch_size
    self.learning_rate = hparams.lr
    # 19 output classes == len(DEFAULT_VALID_LABELS) produced by
    # KITTI.encode_segmap.
    self.net = UNet(num_classes=19, num_layers=hparams.num_layers,
                    features_start=hparams.features_start, bilinear=hparams.bilinear)
    self.transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): presumably per-channel mean/std of the KITTI
        # training images — confirm where these stats were computed.
        transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
                             std=[0.32064945, 0.32098866, 0.32325324]),
    ])
    self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
    # fix: the split token previously carried a stray trailing space
    # ('valid '); harmless only because KITTI tests split == 'train'.
    self.validset = KITTI(self.data_path, split='valid', transform=self.transform)
138
146
139
147
def forward(self, x):
    """Run the UNet over a batch of images and return its raw output."""
    logits = self.net(x)
    return logits
@@ -145,7 +153,21 @@ def training_step(self, batch, batch_nb):
145
153
mask = mask .long ()
146
154
out = self (img )
147
155
loss_val = F .cross_entropy (out , mask , ignore_index = 250 )
148
- return {'loss' : loss_val }
156
+ log_dict = {'train_loss' : loss_val }
157
+ return {'loss' : loss_val , 'log' : log_dict , 'progress_bar' : log_dict }
158
+
159
def validation_step(self, batch, batch_idx):
    """Compute the cross-entropy loss for one validation batch."""
    img, mask = batch
    img = img.float()
    mask = mask.long()
    prediction = self(img)
    # ignore_index=250 matches KITTI.ignore_index used for void classes.
    return {'val_loss': F.cross_entropy(prediction, mask, ignore_index=250)}
166
+
167
def validation_epoch_end(self, outputs):
    """Average the per-batch validation losses and expose them for logging."""
    n_batches = len(outputs)
    total = sum(step['val_loss'] for step in outputs)
    mean_loss = total / n_batches
    metrics = {'val_loss': mean_loss}
    return {'log': metrics, 'val_loss': mean_loss, 'progress_bar': metrics}
149
171
150
172
def configure_optimizers (self ):
151
173
opt = torch .optim .Adam (self .net .parameters (), lr = self .learning_rate )
@@ -155,8 +177,8 @@ def configure_optimizers(self):
155
177
def train_dataloader(self):
    """Shuffled loader over the training split."""
    loader = DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True)
    return loader
157
179
158
- def test_dataloader (self ):
159
- return DataLoader (self .testset , batch_size = self .batch_size , shuffle = False )
180
def val_dataloader(self):
    """Deterministic (unshuffled) loader over the validation split."""
    return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)
160
182
161
183
162
184
def main (hparams ):
@@ -166,24 +188,49 @@ def main(hparams):
166
188
model = SegModel (hparams )
167
189
168
190
# ------------------------
169
- # 2 INIT TRAINER
191
+ # 2 SET LOGGER
192
+ # ------------------------
193
+ logger = False
194
+ if hparams .log_wandb :
195
+ logger = WandbLogger ()
196
+
197
+ # optional: log model topology
198
+ logger .watch (model .net )
199
+
200
+ # ------------------------
201
+ # 3 INIT TRAINER
170
202
# ------------------------
171
203
trainer = pl .Trainer (
172
- gpus = hparams .gpus
204
+ gpus = hparams .gpus ,
205
+ logger = logger ,
206
+ max_epochs = hparams .epochs ,
207
+ accumulate_grad_batches = hparams .grad_batches ,
208
+ distributed_backend = hparams .distributed_backend ,
209
+ precision = 16 if hparams .use_amp else 32 ,
173
210
)
174
211
175
212
# ------------------------
176
- # 3 START TRAINING
213
+ # 5 START TRAINING
177
214
# ------------------------
178
215
trainer .fit (model )
179
216
180
217
181
218
if __name__ == '__main__' :
182
219
parser = ArgumentParser ()
183
- parser .add_argument ("--root" , type = str , help = "path where dataset is stored" )
184
- parser .add_argument ("--gpus" , type = int , help = "number of available GPUs" )
220
+ parser .add_argument ("--data_path" , type = str , help = "path where dataset is stored" )
221
+ parser .add_argument ("--gpus" , type = int , default = - 1 , help = "number of available GPUs" )
222
+ parser .add_argument ('--distributed-backend' , type = str , default = 'dp' , choices = ('dp' , 'ddp' , 'ddp2' ),
223
+ help = 'supports three options dp, ddp, ddp2' )
224
+ parser .add_argument ('--use_amp' , action = 'store_true' , help = 'if true uses 16 bit precision' )
185
225
parser .add_argument ("--batch_size" , type = int , default = 4 , help = "size of the batches" )
186
226
parser .add_argument ("--lr" , type = float , default = 0.001 , help = "adam: learning rate" )
227
+ parser .add_argument ("--num_layers" , type = int , default = 5 , help = "number of layers on u-net" )
228
+ parser .add_argument ("--features_start" , type = float , default = 64 , help = "number of features in first layer" )
229
+ parser .add_argument ("--bilinear" , action = 'store_true' , default = False ,
230
+ help = "whether to use bilinear interpolation or transposed" )
231
+ parser .add_argument ("--grad_batches" , type = int , default = 1 , help = "number of batches to accumulate" )
232
+ parser .add_argument ("--epochs" , type = int , default = 20 , help = "number of epochs to train" )
233
+ parser .add_argument ("--log_wandb" , action = 'store_true' , help = "log training on Weights & Biases" )
187
234
188
235
hparams = parser .parse_args ()
189
236
0 commit comments