Skip to content

Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

License

Notifications You must be signed in to change notification settings

jamesjjcondon/pytorch-unet

This branch is 4 commits behind usuyama/pytorch-unet:master.

Folders and files

NameName
Last commit message
Last commit date

Latest commit

063cf7b · Jun 27, 2019

History

15 Commits
Oct 7, 2018
Jun 27, 2018
Nov 26, 2018
Jun 27, 2019
Jun 27, 2018
Jul 16, 2018
Jun 27, 2018
Dec 30, 2018
Jul 16, 2018
Jul 16, 2018
Jul 16, 2018

Repository files navigation

UNet/FCN PyTorch

This repository contains simple PyTorch implementations of U-Net and FCN, which are deep learning segmentation methods proposed by Ronneberger et al. and Long et al.

Synthetic images/masks for training

First clone the repository and cd into the project directory.

import matplotlib.pyplot as plt
import numpy as np
import helper
import simulation

# Generate some random images
input_images, target_masks = simulation.generate_random_data(192, 192, count=3)

for x in [input_images, target_masks]:
    print(x.shape)
    print(x.min(), x.max())

# Change channel-order and make 3 channels for matplot
input_images_rgb = [x.astype(np.uint8) for x in input_images]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]

# Left: Input image (black and white), Right: Target mask (6ch)
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])

Left: Input image (black and white), Right: Target mask (6ch)

png

Prepare Dataset and DataLoader

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    def __init__(self, count, transform=None):
        self.input_images, self.target_masks = simulation.generate_random_data(192, 192, count=count)
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        if self.transform:
            image = self.transform(image)

        return [image, mask]

# use the same transformations for train/val in this example
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
])

train_set = SimDataset(2000, transform = trans)
val_set = SimDataset(200, transform = trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 25

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

Check the outputs from DataLoader

import torchvision.utils

def reverse_transform(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)

    return inp

# Get a batch of training data
inputs, masks = next(iter(dataloaders['train']))

print(inputs.shape, masks.shape)

plt.imshow(reverse_transform(inputs[3]))
torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])

png

Create the UNet module

import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

Model summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNetUNet(n_class=6)
model = model.to(device)

# check keras-like model summary using torchsummary
from torchsummary import summary
summary(model, input_size=(3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]           9,408
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
         MaxPool2d-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]           4,096
      BatchNorm2d-10           [-1, 64, 56, 56]             128
             ReLU-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]          16,384
      BatchNorm2d-16          [-1, 256, 56, 56]             512
           Conv2d-17          [-1, 256, 56, 56]          16,384
      BatchNorm2d-18          [-1, 256, 56, 56]             512
             ReLU-19          [-1, 256, 56, 56]               0
       Bottleneck-20          [-1, 256, 56, 56]               0
           Conv2d-21           [-1, 64, 56, 56]          16,384
      BatchNorm2d-22           [-1, 64, 56, 56]             128
             ReLU-23           [-1, 64, 56, 56]               0
           Conv2d-24           [-1, 64, 56, 56]          36,864
      BatchNorm2d-25           [-1, 64, 56, 56]             128
             ReLU-26           [-1, 64, 56, 56]               0
           Conv2d-27          [-1, 256, 56, 56]          16,384
      BatchNorm2d-28          [-1, 256, 56, 56]             512
             ReLU-29          [-1, 256, 56, 56]               0
       Bottleneck-30          [-1, 256, 56, 56]               0
           Conv2d-31           [-1, 64, 56, 56]          16,384
      BatchNorm2d-32           [-1, 64, 56, 56]             128
             ReLU-33           [-1, 64, 56, 56]               0
           Conv2d-34           [-1, 64, 56, 56]          36,864
      BatchNorm2d-35           [-1, 64, 56, 56]             128
             ReLU-36           [-1, 64, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          16,384
      BatchNorm2d-38          [-1, 256, 56, 56]             512
             ReLU-39          [-1, 256, 56, 56]               0
       Bottleneck-40          [-1, 256, 56, 56]               0
           Conv2d-41          [-1, 128, 56, 56]          32,768
      BatchNorm2d-42          [-1, 128, 56, 56]             256
             ReLU-43          [-1, 128, 56, 56]               0
           Conv2d-44          [-1, 128, 28, 28]         147,456
      BatchNorm2d-45          [-1, 128, 28, 28]             256
             ReLU-46          [-1, 128, 28, 28]               0
           Conv2d-47          [-1, 512, 28, 28]          65,536
      BatchNorm2d-48          [-1, 512, 28, 28]           1,024
           Conv2d-49          [-1, 512, 28, 28]         131,072
      BatchNorm2d-50          [-1, 512, 28, 28]           1,024
             ReLU-51          [-1, 512, 28, 28]               0
       Bottleneck-52          [-1, 512, 28, 28]               0
           Conv2d-53          [-1, 128, 28, 28]          65,536
      BatchNorm2d-54          [-1, 128, 28, 28]             256
             ReLU-55          [-1, 128, 28, 28]               0
           Conv2d-56          [-1, 128, 28, 28]         147,456
      BatchNorm2d-57          [-1, 128, 28, 28]             256
             ReLU-58          [-1, 128, 28, 28]               0
           Conv2d-59          [-1, 512, 28, 28]          65,536
      BatchNorm2d-60          [-1, 512, 28, 28]           1,024
             ReLU-61          [-1, 512, 28, 28]               0
       Bottleneck-62          [-1, 512, 28, 28]               0
           Conv2d-63          [-1, 128, 28, 28]          65,536
      BatchNorm2d-64          [-1, 128, 28, 28]             256
             ReLU-65          [-1, 128, 28, 28]               0
           Conv2d-66          [-1, 128, 28, 28]         147,456
      BatchNorm2d-67          [-1, 128, 28, 28]             256
             ReLU-68          [-1, 128, 28, 28]               0
           Conv2d-69          [-1, 512, 28, 28]          65,536
      BatchNorm2d-70          [-1, 512, 28, 28]           1,024
             ReLU-71          [-1, 512, 28, 28]               0
       Bottleneck-72          [-1, 512, 28, 28]               0
           Conv2d-73          [-1, 128, 28, 28]          65,536
      BatchNorm2d-74          [-1, 128, 28, 28]             256
             ReLU-75          [-1, 128, 28, 28]               0
           Conv2d-76          [-1, 128, 28, 28]         147,456
      BatchNorm2d-77          [-1, 128, 28, 28]             256
             ReLU-78          [-1, 128, 28, 28]               0
           Conv2d-79          [-1, 512, 28, 28]          65,536
      BatchNorm2d-80          [-1, 512, 28, 28]           1,024
             ReLU-81          [-1, 512, 28, 28]               0
       Bottleneck-82          [-1, 512, 28, 28]               0
           Conv2d-83          [-1, 256, 28, 28]         131,072
      BatchNorm2d-84          [-1, 256, 28, 28]             512
             ReLU-85          [-1, 256, 28, 28]               0
           Conv2d-86          [-1, 256, 14, 14]         589,824
      BatchNorm2d-87          [-1, 256, 14, 14]             512
             ReLU-88          [-1, 256, 14, 14]               0
           Conv2d-89         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
           Conv2d-91         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-92         [-1, 1024, 14, 14]           2,048
             ReLU-93         [-1, 1024, 14, 14]               0
       Bottleneck-94         [-1, 1024, 14, 14]               0
           Conv2d-95          [-1, 256, 14, 14]         262,144
      BatchNorm2d-96          [-1, 256, 14, 14]             512
             ReLU-97          [-1, 256, 14, 14]               0
           Conv2d-98          [-1, 256, 14, 14]         589,824
      BatchNorm2d-99          [-1, 256, 14, 14]             512
            ReLU-100          [-1, 256, 14, 14]               0
          Conv2d-101         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-102         [-1, 1024, 14, 14]           2,048
            ReLU-103         [-1, 1024, 14, 14]               0
      Bottleneck-104         [-1, 1024, 14, 14]               0
          Conv2d-105          [-1, 256, 14, 14]         262,144
     BatchNorm2d-106          [-1, 256, 14, 14]             512
            ReLU-107          [-1, 256, 14, 14]               0
          Conv2d-108          [-1, 256, 14, 14]         589,824
     BatchNorm2d-109          [-1, 256, 14, 14]             512
            ReLU-110          [-1, 256, 14, 14]               0
          Conv2d-111         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
            ReLU-113         [-1, 1024, 14, 14]               0
      Bottleneck-114         [-1, 1024, 14, 14]               0
          Conv2d-115          [-1, 256, 14, 14]         262,144
     BatchNorm2d-116          [-1, 256, 14, 14]             512
            ReLU-117          [-1, 256, 14, 14]               0
          Conv2d-118          [-1, 256, 14, 14]         589,824
     BatchNorm2d-119          [-1, 256, 14, 14]             512
            ReLU-120          [-1, 256, 14, 14]               0
          Conv2d-121         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-122         [-1, 1024, 14, 14]           2,048
            ReLU-123         [-1, 1024, 14, 14]               0
      Bottleneck-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 256, 14, 14]         262,144
     BatchNorm2d-126          [-1, 256, 14, 14]             512
            ReLU-127          [-1, 256, 14, 14]               0
          Conv2d-128          [-1, 256, 14, 14]         589,824
     BatchNorm2d-129          [-1, 256, 14, 14]             512
            ReLU-130          [-1, 256, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
            ReLU-133         [-1, 1024, 14, 14]               0
      Bottleneck-134         [-1, 1024, 14, 14]               0
          Conv2d-135          [-1, 256, 14, 14]         262,144
     BatchNorm2d-136          [-1, 256, 14, 14]             512
            ReLU-137          [-1, 256, 14, 14]               0
          Conv2d-138          [-1, 256, 14, 14]         589,824
     BatchNorm2d-139          [-1, 256, 14, 14]             512
            ReLU-140          [-1, 256, 14, 14]               0
          Conv2d-141         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-142         [-1, 1024, 14, 14]           2,048
            ReLU-143         [-1, 1024, 14, 14]               0
      Bottleneck-144         [-1, 1024, 14, 14]               0
          Conv2d-145          [-1, 512, 14, 14]         524,288
     BatchNorm2d-146          [-1, 512, 14, 14]           1,024
            ReLU-147          [-1, 512, 14, 14]               0
          Conv2d-148            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-149            [-1, 512, 7, 7]           1,024
            ReLU-150            [-1, 512, 7, 7]               0
          Conv2d-151           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-152           [-1, 2048, 7, 7]           4,096
          Conv2d-153           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-154           [-1, 2048, 7, 7]           4,096
            ReLU-155           [-1, 2048, 7, 7]               0
      Bottleneck-156           [-1, 2048, 7, 7]               0
          Conv2d-157            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-158            [-1, 512, 7, 7]           1,024
            ReLU-159            [-1, 512, 7, 7]               0
          Conv2d-160            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-161            [-1, 512, 7, 7]           1,024
            ReLU-162            [-1, 512, 7, 7]               0
          Conv2d-163           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-164           [-1, 2048, 7, 7]           4,096
            ReLU-165           [-1, 2048, 7, 7]               0
      Bottleneck-166           [-1, 2048, 7, 7]               0
          Conv2d-167            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-168            [-1, 512, 7, 7]           1,024
            ReLU-169            [-1, 512, 7, 7]               0
          Conv2d-170            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-171            [-1, 512, 7, 7]           1,024
            ReLU-172            [-1, 512, 7, 7]               0
          Conv2d-173           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-174           [-1, 2048, 7, 7]           4,096
            ReLU-175           [-1, 2048, 7, 7]               0
      Bottleneck-176           [-1, 2048, 7, 7]               0
          Conv2d-177           [-1, 1024, 7, 7]       2,098,176
            ReLU-178           [-1, 1024, 7, 7]               0
        Upsample-179         [-1, 1024, 14, 14]               0
          Conv2d-180          [-1, 512, 14, 14]         524,800
            ReLU-181          [-1, 512, 14, 14]               0
          Conv2d-182          [-1, 512, 14, 14]       7,078,400
            ReLU-183          [-1, 512, 14, 14]               0
        Upsample-184          [-1, 512, 28, 28]               0
          Conv2d-185          [-1, 512, 28, 28]         262,656
            ReLU-186          [-1, 512, 28, 28]               0
          Conv2d-187          [-1, 512, 28, 28]       4,719,104
            ReLU-188          [-1, 512, 28, 28]               0
        Upsample-189          [-1, 512, 56, 56]               0
          Conv2d-190          [-1, 256, 56, 56]          65,792
            ReLU-191          [-1, 256, 56, 56]               0
          Conv2d-192          [-1, 256, 56, 56]       1,769,728
            ReLU-193          [-1, 256, 56, 56]               0
        Upsample-194        [-1, 256, 112, 112]               0
          Conv2d-195         [-1, 64, 112, 112]           4,160
            ReLU-196         [-1, 64, 112, 112]               0
          Conv2d-197        [-1, 128, 112, 112]         368,768
            ReLU-198        [-1, 128, 112, 112]               0
        Upsample-199        [-1, 128, 224, 224]               0
          Conv2d-200         [-1, 64, 224, 224]         110,656
            ReLU-201         [-1, 64, 224, 224]               0
          Conv2d-202          [-1, 6, 224, 224]             390
================================================================
Total params: 40,549,382
Trainable params: 40,549,382
Non-trainable params: 0
----------------------------------------------------------------

Define the main training loop

from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss

def print_metrics(metrics, epoch_samples, phase):
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))

    print("{}: {}".format(phase, ", ".join(outputs)))

def train_model(model, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Training

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 6
model = ResNetUNet(num_class).to(device)

# freeze backbone layers
#for l in model.base_layers:
#    for param in l.parameters():
#        param.requires_grad = False

optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=60)
cuda:0
Epoch 0/59
----------
LR 0.0001
train: bce: 0.070256, dice: 0.856320, loss: 0.463288
val: bce: 0.014897, dice: 0.515814, loss: 0.265356
saving best model
0m 51s
Epoch 1/59
----------
LR 0.0001
train: bce: 0.011369, dice: 0.309445, loss: 0.160407
val: bce: 0.003790, dice: 0.113682, loss: 0.058736
saving best model
0m 51s
Epoch 2/59
----------
LR 0.0001
train: bce: 0.003480, dice: 0.089928, loss: 0.046704
val: bce: 0.002525, dice: 0.067604, loss: 0.035064
saving best model
0m 51s

(Omitted)

Epoch 57/59
----------
LR 1e-05
train: bce: 0.000523, dice: 0.010289, loss: 0.005406
val: bce: 0.001558, dice: 0.030965, loss: 0.016261
0m 51s
Epoch 58/59
----------
LR 1e-05
train: bce: 0.000518, dice: 0.010209, loss: 0.005364
val: bce: 0.001548, dice: 0.031034, loss: 0.016291
0m 51s
Epoch 59/59
----------
LR 1e-05
train: bce: 0.000518, dice: 0.010168, loss: 0.005343
val: bce: 0.001566, dice: 0.030785, loss: 0.016176
0m 50s
Best val loss: 0.016171

Use the trained model

import math

model.eval()   # Set model to the evaluation mode

# Create another simulation dataset for test
test_dataset = SimDataset(3, transform = trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

# Get the first batch
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# Predict
pred = model(inputs)
# The loss functions include the sigmoid function.
pred = F.sigmoid(pred)
pred = pred.data.cpu().numpy()
print(pred.shape)

# Change channel-order and make 3 channels for matplot
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
(3, 6, 192, 192)

Left: Input image, Middle: Correct mask (Ground-truth), Rigth: Predicted mask

png

About

Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Jupyter Notebook 96.4%
  • Python 3.6%