-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathModel.py
273 lines (235 loc) · 10.9 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from torch.nn import init
import torch.nn.functional as F
from cnn_utils import *
import math
import torch
#============================================
__author__ = "Sachin Mehta"
__license__ = "MIT"
__maintainer__ = "Sachin Mehta"
#============================================
class EESP(nn.Module):
    '''
    This class defines the EESP block, which is based on the following principle
        REDUCE ---> SPLIT ---> TRANSFORM --> MERGE
    '''

    def __init__(self, nIn, nOut, stride=1, k=4, r_lim=7, down_method='esp'):  # down_method --> ['avg' or 'esp']
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels (must be divisible by k)
        :param stride: stride handed to every dilated branch. If 2, the feature map is down-sampled by 2
        :param k: # of parallel branches
        :param r_lim: A maximum value of receptive field allowed for EESP block
        :param down_method: 'avg' or 'esp'. With 'avg' and stride 2 the forward pass returns the merged
            features without residual link and without the final activation
        '''
        super().__init__()
        self.stride = stride
        n = int(nOut / k)
        n1 = nOut - (k - 1) * n
        assert down_method in ['avg', 'esp'], 'One of these is suppported (avg or esp)'
        assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1)
        # Reduce: grouped 1x1 projection from nIn to n = nOut / k channels
        self.proj_1x1 = CBR(nIn, n, 1, stride=1, groups=k)

        # One kernel size per branch, sorted in ascending order
        self.k_sizes = self._branch_kernel_sizes(k, r_lim)

        # Transform: one depth-wise (groups=n) dilated 3x3 convolution per branch
        self.spp_dw = nn.ModuleList()
        for ksize in self.k_sizes:
            d_rate = self._dilation_for(ksize)
            self.spp_dw.append(CDilated(n, n, kSize=3, stride=stride, groups=n, d=d_rate))
        # Performing a group convolution with K groups is the same as performing K point-wise convolutions
        self.conv_1x1_exp = CB(nOut, nOut, 1, 1, groups=k)
        self.br_after_cat = BR(nOut)
        self.module_act = nn.PReLU(nOut)
        self.downAvg = True if down_method == 'avg' else False

    @staticmethod
    def _branch_kernel_sizes(k, r_lim):
        '''
        Return the sorted list of effective kernel sizes, one per branch.

        Branch i nominally gets kernel size 3 + 2 * i. Once that exceeds r_lim the branch falls back
        to the base kernel size of 3. Sorting in ascending order puts the 3x3 branches first, which
        the original author relies on for hierarchical feature fusion.
        '''
        sizes = []
        for i in range(k):
            ksize = int(3 + 2 * i)
            sizes.append(ksize if ksize <= r_lim else 3)
        sizes.sort()
        return sizes

    @staticmethod
    def _dilation_for(ksize):
        '''
        Return the dilation rate that gives a 3x3 kernel an effective size of ksize.

        Equals the former lookup table {3: 1, 5: 2, ..., 17: 8} on every entry it contained, but
        also works for larger odd kernel sizes, so k > 8 no longer raises KeyError.
        '''
        return (ksize - 1) // 2

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        # Reduce --> project high-dimensional feature maps to low-dimensional space
        output1 = self.proj_1x1(input)
        output = [self.spp_dw[0](output1)]
        # compute the output for each branch and hierarchically fuse them
        # i.e. Split --> Transform --> HFF
        for idx in range(1, len(self.spp_dw)):
            out_k = self.spp_dw[idx](output1)
            # HFF: add the (already fused) output of the previous branch
            out_k = out_k + output[-1]
            output.append(out_k)

        # Merge: concatenate the branches, apply br_after_cat, then the grouped 1x1 expansion
        merged = torch.cat(output, 1)
        del output
        expanded = self.conv_1x1_exp(self.br_after_cat(merged))

        # if down-sampling with 'avg', return the un-activated features:
        # the caller combines them with an avg. pooled feature map and applies its own activation
        if self.stride == 2 and self.downAvg:
            return expanded

        # if dimensions of input and merged features are the same, add them (RESIDUAL LINK)
        if expanded.size() == input.size():
            expanded = expanded + input

        # Threshold the feature map using activation function (PReLU in this case)
        return self.module_act(expanded)
class DownSampler(nn.Module):
    '''
    Down-sampling function that has three parallel branches: (1) avg pooling,
    (2) EESP block with stride of 2 and (3) efficient long-range connection with the input.
    The output feature maps of branches from (1) and (2) are concatenated and then additively fused with (3) to produce
    the final output.
    '''

    def __init__(self, nin, nout, k=4, r_lim=9, reinf=True):
        '''
        :param nin: number of input channels
        :param nout: number of output channels
        :param k: # of parallel branches
        :param r_lim: A maximum value of receptive field allowed for EESP block
        :param reinf: Use long range shortcut connection with the input or not.
        '''
        super().__init__()
        # The avg-pool branch keeps nin channels, so the EESP branch supplies the remaining ones
        nout_new = nout - nin
        self.eesp = EESP(nin, nout_new, stride=2, k=k, r_lim=r_lim, down_method='avg')
        self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2)
        if reinf:
            # ``config_inp_reinf`` is a module-level global assigned by EESPNet.__init__.
            # Fall back to 3 (the value EESPNet assigns) so this class also works when
            # it is constructed before any EESPNet instance exists.
            reinf_channels = globals().get('config_inp_reinf', 3)
            self.inp_reinf = nn.Sequential(
                CBR(reinf_channels, reinf_channels, 3, 1),
                CB(reinf_channels, nout, 1, 1)
            )
        self.act = nn.PReLU(nout)

    def forward(self, input, input2=None):
        '''
        :param input: input feature map
        :param input2: optional tensor for the long-range shortcut. It is average-pooled
            until its height matches the down-sampled feature map. Passing it requires
            that the module was built with reinf=True
        :return: feature map down-sampled by a factor of 2
        :raises ValueError: if input2 cannot be pooled to the target height
        '''
        avg_out = self.avg(input)
        eesp_out = self.eesp(input)
        output = torch.cat([avg_out, eesp_out], 1)

        if input2 is not None:
            # assuming the input is a square image (only the height is compared)
            # Shortcut connection with the input image
            w1 = avg_out.size(2)
            while True:
                input2 = F.avg_pool2d(input2, kernel_size=3, padding=1, stride=2)
                w2 = input2.size(2)
                if w2 == w1:
                    break
                if w2 < w1:
                    # Pooling only ever shrinks the map, so the target can no longer be
                    # reached; fail loudly instead of looping forever.
                    raise ValueError(
                        'input2 cannot be down-sampled to match the feature map '
                        '(reached size {} while the target is {})'.format(w2, w1)
                    )
            output = output + self.inp_reinf(input2)

        return self.act(output)
class EESPNet(nn.Module):
    '''
    This class defines the ESPNetv2 architecture for the ImageNet classification
    '''

    def __init__(self, classes=1000, s=1):
        '''
        :param classes: number of classes in the dataset. Default is 1000 for the ImageNet dataset
        :param s: factor that scales the number of output feature maps
        :raises ValueError: if s is larger than 2.0 (no final layer width is defined for it)
        '''
        super().__init__()
        reps = [0, 3, 7, 3]  # how many times EESP blocks should be repeated at each spatial level.
        channels = 3

        r_lim = [13, 11, 9, 7, 5]  # receptive field at each spatial level
        K = [4] * len(r_lim)  # No. of parallel branches at different levels

        base = 32  # base configuration
        config_len = 5
        config = [base] * config_len
        base_s = 0
        for i in range(config_len):
            if i == 0:
                # scale the base width and round it up to a multiple of the branch count
                base_s = int(base * s)
                base_s = math.ceil(base_s / K[0]) * K[0]
                # the first level never gets wider than the base width
                config[i] = base if base_s > base else base_s
            else:
                config[i] = base_s * pow(2, i)
        # width of the final expansion layer that feeds the classifier
        if s <= 1.5:
            config.append(1024)
        elif s <= 2.0:
            config.append(1280)
        else:
            # previously the exception was created but never raised, so the constructor
            # went on and failed later with an IndexError on config[5]
            raise ValueError('Configuration not supported')

        # DownSampler reads this module-level global to size its input-reinforcement branch
        global config_inp_reinf
        config_inp_reinf = 3
        self.input_reinforcement = True  # True for the shortcut connection with input
        assert len(K) == len(r_lim), 'Length of branching factor array and receptive field array should be the same.'

        # The sizes in the trailing comments are the original author's and presumably assume a 224x224 input
        self.level1 = CBR(channels, config[0], 3, 2)  # 112 L1

        self.level2_0 = DownSampler(config[0], config[1], k=K[0], r_lim=r_lim[0], reinf=self.input_reinforcement)  # out = 56

        self.level3_0 = DownSampler(config[1], config[2], k=K[1], r_lim=r_lim[1], reinf=self.input_reinforcement)  # out = 28
        self.level3 = nn.ModuleList()
        for _ in range(reps[1]):
            self.level3.append(EESP(config[2], config[2], stride=1, k=K[2], r_lim=r_lim[2]))

        self.level4_0 = DownSampler(config[2], config[3], k=K[2], r_lim=r_lim[2], reinf=self.input_reinforcement)  # out = 14
        self.level4 = nn.ModuleList()
        for _ in range(reps[2]):
            self.level4.append(EESP(config[3], config[3], stride=1, k=K[3], r_lim=r_lim[3]))

        # built with the default reinf=True, although forward() never feeds it the input image
        self.level5_0 = DownSampler(config[3], config[4], k=K[3], r_lim=r_lim[3])  # 7
        self.level5 = nn.ModuleList()
        for _ in range(reps[3]):
            self.level5.append(EESP(config[4], config[4], stride=1, k=K[4], r_lim=r_lim[4]))

        # expand the feature maps using depth-wise convolution followed by group point-wise convolution
        self.level5.append(CBR(config[4], config[4], 3, 1, groups=config[4]))
        self.level5.append(CBR(config[4], config[5], 1, 1, groups=K[4]))

        self.classifier = nn.Linear(config[5], classes)
        self.init_params()

    def init_params(self):
        '''
        Function to initialize the parameters
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input, p=0.2):
        '''
        :param input: Receives the input RGB image
        :param p: dropout probability applied to the pooled features (active only in training mode)
        :return: a C-dimensional vector, C=# of classes
        '''
        out_l1 = self.level1(input)  # 112
        if not self.input_reinforcement:
            # without input reinforcement the down-samplers must not receive the image
            input = None

        out_l2 = self.level2_0(out_l1, input)  # 56

        out_l3 = self.level3_0(out_l2, input)  # down-sample
        for layer in self.level3:
            out_l3 = layer(out_l3)

        out_l4 = self.level4_0(out_l3, input)  # down-sample
        for layer in self.level4:
            out_l4 = layer(out_l4)

        out_l5 = self.level5_0(out_l4)  # down-sample (no shortcut with the input image here)
        for layer in self.level5:
            out_l5 = layer(out_l5)

        # global average pooling, dropout, flatten, classify
        output_g = F.adaptive_avg_pool2d(out_l5, output_size=1)
        output_g = F.dropout(output_g, p=p, training=self.training)
        output_1x1 = output_g.view(output_g.size(0), -1)

        return self.classifier(output_1x1)
if __name__ == '__main__':
    # Smoke test: push one random 224x224 RGB batch through the network and print the output size.
    # The model and the input must live on the same device; use the GPU only when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EESPNet(classes=1000, s=1.0).to(device)
    model.eval()
    # torch.randn gives defined values, unlike the uninitialised memory of torch.Tensor(...)
    sample = torch.randn(1, 3, 224, 224, device=device)
    with torch.no_grad():
        out = model(sample)
    print('Output size')
    print(out.size())