From 51896536d0d1d21aa057dfd84205d773fc3aba80 Mon Sep 17 00:00:00 2001 From: akshay Date: Mon, 17 Feb 2020 19:40:48 +0530 Subject: [PATCH] changed to absolute imports and added docs --- .../models/unet/__init__.py | 4 ---- .../models/unet/model.py | 19 +++++++++++-------- 2 files changed, 11 insertions(+), 12 deletions(-) delete mode 100644 pl_examples/full_examples/semantic_segmentation/models/unet/__init__.py diff --git a/pl_examples/full_examples/semantic_segmentation/models/unet/__init__.py b/pl_examples/full_examples/semantic_segmentation/models/unet/__init__.py deleted file mode 100644 index 903ed9a2e78c1..0000000000000 --- a/pl_examples/full_examples/semantic_segmentation/models/unet/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# For relative imports to work in Python 3.6 -import os -import sys -sys.path.append(os.path.dirname(os.path.realpath(__file__))) diff --git a/pl_examples/full_examples/semantic_segmentation/models/unet/model.py b/pl_examples/full_examples/semantic_segmentation/models/unet/model.py index ee87104599b0a..c83516d9a04f4 100644 --- a/pl_examples/full_examples/semantic_segmentation/models/unet/model.py +++ b/pl_examples/full_examples/semantic_segmentation/models/unet/model.py @@ -2,30 +2,33 @@ import torch.nn as nn import torch.nn.functional as F -from parts import DoubleConv, Down, Up +from models.unet.parts import DoubleConv, Down, Up class UNet(nn.Module): ''' Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation Link - https://arxiv.org/abs/1505.04597 + + Parameters: + num_classes (int) - Number of output classes required (default 19 for KITTI dataset) + bilinear (bool) - Whether to use bilinear interpolation or transposed + convolutions for upsampling. ''' def __init__(self, num_classes=19, bilinear=False): super().__init__() - self.bilinear = bilinear - self.num_classes = num_classes self.layer1 = DoubleConv(3, 64) self.layer2 = Down(64, 128) self.layer3 = Down(128, 256) self.layer4 = Down(256, 512) self.layer5 = Down(512, 1024) - self.layer6 = Up(1024, 512, bilinear=self.bilinear) - self.layer7 = Up(512, 256, bilinear=self.bilinear) - self.layer8 = Up(256, 128, bilinear=self.bilinear) - self.layer9 = Up(128, 64, bilinear=self.bilinear) + self.layer6 = Up(1024, 512, bilinear=bilinear) + self.layer7 = Up(512, 256, bilinear=bilinear) + self.layer8 = Up(256, 128, bilinear=bilinear) + self.layer9 = Up(128, 64, bilinear=bilinear) - self.layer10 = nn.Conv2d(64, self.num_classes, kernel_size=1) + self.layer10 = nn.Conv2d(64, num_classes, kernel_size=1) def forward(self, x): x1 = self.layer1(x)