-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathtrain_utils.py
130 lines (97 loc) · 3.96 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#============================================
__author__ = "Sachin Mehta"
__license__ = "MIT"
__maintainer__ = "Sachin Mehta"
#============================================
from IOUEval import iouEval
import time
import torch
import numpy as np
def poly_lr_scheduler(args, optimizer, epoch, power=0.9):
    '''
    Set a polynomially-decayed learning rate on every param group of the optimizer.

    lr = args.lr * (1 - epoch / args.max_epochs) ** power, rounded to 8 decimals.

    :param args: general arguments; reads ``args.lr`` (base learning rate) and
                 ``args.max_epochs`` (epoch at which the rate reaches 0)
    :param optimizer: optimizer whose ``param_groups`` entries get their 'lr' overwritten
    :param epoch: current epoch number
    :param power: exponent of the polynomial decay
    :return: the learning rate that was applied
    '''
    # Clamp at 0: past max_epochs the base would be negative, and a negative
    # float raised to a fractional power yields a complex number, which made
    # round() raise TypeError.
    decay = max(0.0, 1 - epoch / args.max_epochs)
    lr = round(args.lr * decay ** power, 8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def val(args, val_loader, model, criterion):
    '''
    Run one pass over the validation set, collecting loss and segmentation metrics.

    :param args: general arguments; reads ``args.classes`` (passed to iouEval) and
                 ``args.onGPU`` (when truthy, each batch is moved to CUDA)
    :param val_loader: loader for the validation dataset, yielding (input, target) batches
    :param model: model; called as ``model(input)`` and expected to return a single output
    :param criterion: loss function, called as ``criterion(output, target)``
    :return: average epoch loss, overall pixel-wise accuracy, per class accuracy, per class iu, and mIOU
    '''
    # switch to evaluation mode
    model.eval()
    iouEvalVal = iouEval(args.classes)
    epoch_loss = []
    total_batches = len(val_loader)
    # Validation needs no gradients. Disabling autograd stops PyTorch from building
    # the backward graph and retaining activations for every batch.
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            start_time = time.time()
            if args.onGPU:
                input = input.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
            # run the model
            output1 = model(input)
            # compute the loss
            loss = criterion(output1, target)
            loss_value = loss.item()
            epoch_loss.append(loss_value)
            time_taken = time.time() - start_time
            # compute the confusion matrix from the argmax over dim 1
            # (presumably the class dimension of (N, C, H, W) logits -- TODO confirm)
            iouEvalVal.addBatch(output1.max(1)[1].data, target.data)
            print('[%d/%d] loss: %.3f time: %.2f' % (i, total_batches, loss_value, time_taken))
    average_epoch_loss_val = sum(epoch_loss) / len(epoch_loss)
    overall_acc, per_class_acc, per_class_iu, mIOU = iouEvalVal.getMetric()
    return average_epoch_loss_val, overall_acc, per_class_acc, per_class_iu, mIOU
def train(args, train_loader, model, criterion, optimizer, epoch):
    '''
    Run one training epoch, optimizing the summed loss of the model's two outputs.

    :param args: general arguments; reads ``args.classes`` (passed to iouEval) and
                 ``args.onGPU`` (when truthy, each batch is moved to CUDA)
    :param train_loader: loader for the training dataset, yielding (input, target) batches
    :param model: model; called as ``model(input)`` and expected to return two outputs
    :param criterion: loss function, applied to each output against the same target
    :param optimizer: optimization algo, such as ADAM or SGD
    :param epoch: epoch number (not used inside this function)
    :return: average epoch loss, overall pixel-wise accuracy, per class accuracy, per class iu, and mIOU
    '''
    # switch to train mode
    model.train()
    iouEvalTrain = iouEval(args.classes)
    epoch_loss = []
    total_batches = len(train_loader)
    for i, (input, target) in enumerate(train_loader):
        start_time = time.time()
        if args.onGPU:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        # run the model; both outputs are supervised against the same target
        output1, output2 = model(input)
        loss1 = criterion(output1, target)
        loss2 = criterion(output2, target)
        loss = loss1 + loss2
        # clear stale gradients once, right before backprop (was called twice)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        epoch_loss.append(loss_value)
        time_taken = time.time() - start_time
        # compute the confusion matrix; only output1 contributes to the metrics
        iouEvalTrain.addBatch(output1.max(1)[1].data, target.data)
        print('[%d/%d] loss: %.3f time:%.2f' % (i, total_batches, loss_value, time_taken))
    average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
    overall_acc, per_class_acc, per_class_iu, mIOU = iouEvalTrain.getMetric()
    return average_epoch_loss_train, overall_acc, per_class_acc, per_class_iu, mIOU
def save_checkpoint(state, filenameCheckpoint='checkpoint.pth.tar'):
    '''
    Persist a checkpoint object to disk via ``torch.save``.

    :param state: object to serialize (e.g. a dict of model/optimizer state)
    :param filenameCheckpoint: destination path of the checkpoint file
    :return: None
    '''
    torch.save(obj=state, f=filenameCheckpoint)
def netParams(model):
    '''
    helper function to count the total number of network parameters

    :param model: model exposing ``parameters()`` (e.g. a torch.nn.Module)
    :return: total number of scalar entries across all parameter tensors, as an int
    '''
    # numel() returns the element count directly (1 for 0-d tensors). Summing
    # Python ints keeps the result an exact int, including 0 for a model with
    # no parameters. The old np.prod/np.sum version returned floats in those cases.
    return sum(parameter.numel() for parameter in model.parameters())