|
2 | 2 | import torch.nn as nn
|
3 | 3 | import torch.nn.functional as F
|
4 | 4 |
|
5 |
| -from parts import DoubleConv, Down, Up |
| 5 | +from models.unet.parts import DoubleConv, Down, Up |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class UNet(nn.Module):
|
9 | 9 | '''
|
10 | 10 | Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation
|
11 | 11 | Link - https://arxiv.org/abs/1505.04597
|
| 12 | +
|
| 13 | + Parameters: |
| 14 | + num_classes (int) - Number of output classes required (default 19 for KITTI dataset) |
| 15 | + bilinear (bool) - Whether to use bilinear interpolation or transposed |
| 16 | + convolutions for upsampling. |
12 | 17 | '''
|
13 | 18 | def __init__(self, num_classes=19, bilinear=False):
|
14 | 19 | super().__init__()
|
15 |
| - self.bilinear = bilinear |
16 |
| - self.num_classes = num_classes |
17 | 20 | self.layer1 = DoubleConv(3, 64)
|
18 | 21 | self.layer2 = Down(64, 128)
|
19 | 22 | self.layer3 = Down(128, 256)
|
20 | 23 | self.layer4 = Down(256, 512)
|
21 | 24 | self.layer5 = Down(512, 1024)
|
22 | 25 |
|
23 |
| - self.layer6 = Up(1024, 512, bilinear=self.bilinear) |
24 |
| - self.layer7 = Up(512, 256, bilinear=self.bilinear) |
25 |
| - self.layer8 = Up(256, 128, bilinear=self.bilinear) |
26 |
| - self.layer9 = Up(128, 64, bilinear=self.bilinear) |
| 26 | + self.layer6 = Up(1024, 512, bilinear=bilinear) |
| 27 | + self.layer7 = Up(512, 256, bilinear=bilinear) |
| 28 | + self.layer8 = Up(256, 128, bilinear=bilinear) |
| 29 | + self.layer9 = Up(128, 64, bilinear=bilinear) |
27 | 30 |
|
28 |
| - self.layer10 = nn.Conv2d(64, self.num_classes, kernel_size=1) |
| 31 | + self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1) |
29 | 32 |
|
30 | 33 | def forward(self, x):
|
31 | 34 | x1 = self.layer1(x)
|
|
0 commit comments