bnsapa committed
Commit 1455b85 · 1 Parent(s): 6aec5cf

Fix bug in login functionality

Files changed (1)
  1. TwinLite.py +468 -0
TwinLite.py ADDED
@@ -0,0 +1,468 @@
import torch
import torch.nn as nn

from torch.nn import Module, Conv2d, Parameter, Softmax


class PAM_Module(Module):
    """ Position attention module"""
    # Ref from SAGAN
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim

        self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = Parameter(torch.zeros(1))

        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X H X W)
        returns :
            out : attention value + input feature
            attention : B X (HxW) X (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out


class CAM_Module(Module):
    """ Channel attention module"""
    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.chanel_in = in_dim

        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X H X W)
        returns :
            out : attention value + input feature
            attention : B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out
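

# Note: PAM_Module forms an (H*W) x (H*W) attention matrix, so its memory cost grows
# quadratically with spatial resolution; in this file both attention modules are only
# applied to the 32-channel, 1/8-resolution encoder output. Illustrative shape check
# (assuming a 45x80 feature map, e.g. a 360x640 input down-sampled by 8):
#
#   pam = PAM_Module(32)
#   cam = CAM_Module(32)
#   y = pam(torch.randn(1, 32, 45, 80))   # -> (1, 32, 45, 80)
#   z = cam(torch.randn(1, 32, 45, 80))   # -> (1, 32, 45, 80)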


class UPx2(nn.Module):
    '''
    This class defines a 2x up-sampling layer (transposed convolution) with batch
    normalization and PReLU activation
    '''
    def __init__(self, nIn, nOut):
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels
        '''
        super().__init__()
        self.deconv = nn.ConvTranspose2d(nIn, nOut, 2, stride=2, padding=0, output_padding=0, bias=False)
        self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
        self.act = nn.PReLU(nOut)

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        output = self.deconv(input)
        output = self.bn(output)
        output = self.act(output)
        return output

    def fuseforward(self, input):
        # forward pass that skips batch norm (for use after BN has been fused into the weights)
        output = self.deconv(input)
        output = self.act(output)
        return output


class CBR(nn.Module):
    '''
    This class defines the convolution layer with batch normalization and PReLU activation
    '''
    def __init__(self, nIn, nOut, kSize, stride=1):
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: stride rate for down-sampling. Default is 1
        '''
        super().__init__()
        padding = int((kSize - 1)/2)
        # self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
        # self.conv1 = nn.Conv2d(nOut, nOut, (1, kSize), stride=1, padding=(0, padding), bias=False)
        self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
        self.act = nn.PReLU(nOut)

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        output = self.conv(input)
        # output = self.conv1(output)
        output = self.bn(output)
        output = self.act(output)
        return output

    def fuseforward(self, input):
        output = self.conv(input)
        output = self.act(output)
        return output


class CB(nn.Module):
    '''
    This class groups the convolution and batch normalization
    '''
    def __init__(self, nIn, nOut, kSize, stride=1):
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: optional stride for down-sampling
        '''
        super().__init__()
        padding = int((kSize - 1)/2)
        self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)
        self.bn = nn.BatchNorm2d(nOut, eps=1e-03)

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        output = self.conv(input)
        output = self.bn(output)
        return output


class C(nn.Module):
    '''
    This class is for a convolutional layer.
    '''
    def __init__(self, nIn, nOut, kSize, stride=1):
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: optional stride rate for down-sampling
        '''
        super().__init__()
        padding = int((kSize - 1)/2)
        self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False)

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        output = self.conv(input)
        return output


class CDilated(nn.Module):
    '''
    This class defines the dilated convolution.
    '''
    def __init__(self, nIn, nOut, kSize, stride=1, d=1):
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: optional stride rate for down-sampling
        :param d: optional dilation rate
        '''
        super().__init__()
        padding = int((kSize - 1)/2) * d
        self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False, dilation=d)

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        output = self.conv(input)
        return output


class DownSamplerB(nn.Module):
    '''
    Down-sampling ESP block: a strided 3x3 reduction followed by parallel dilated
    convolutions whose outputs are hierarchically fused and concatenated.
    '''
    def __init__(self, nIn, nOut):
        super().__init__()
        n = int(nOut/5)
        n1 = nOut - 4*n
        self.c1 = C(nIn, n, 3, 2)
        self.d1 = CDilated(n, n1, 3, 1, 1)
        self.d2 = CDilated(n, n, 3, 1, 2)
        self.d4 = CDilated(n, n, 3, 1, 4)
        self.d8 = CDilated(n, n, 3, 1, 8)
        self.d16 = CDilated(n, n, 3, 1, 16)
        self.bn = nn.BatchNorm2d(nOut, eps=1e-3)
        self.act = nn.PReLU(nOut)

    def forward(self, input):
        output1 = self.c1(input)
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)
        d16 = self.d16(output1)

        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        add4 = add3 + d16

        combine = torch.cat([d1, add1, add2, add3, add4], 1)
        # combine_in_out = input + combine
        output = self.bn(combine)
        output = self.act(output)
        return output


class BR(nn.Module):
    '''
    This class groups the batch normalization and PReLU activation
    '''
    def __init__(self, nOut):
        '''
        :param nOut: number of output feature maps
        '''
        super().__init__()
        self.nOut = nOut
        self.bn = nn.BatchNorm2d(nOut, eps=1e-03)
        self.act = nn.PReLU(nOut)

    def forward(self, input):
        '''
        :param input: input feature map
        :return: normalized and thresholded feature map
        '''
        output = self.bn(input)
        output = self.act(output)
        return output


class DilatedParllelResidualBlockB(nn.Module):
    '''
    This class defines the ESP block, which is based on the following principle:
        Reduce ---> Split ---> Transform ---> Merge
    '''
    def __init__(self, nIn, nOut, add=True):
        '''
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param add: if true, add a residual connection through an identity operation. A projection
            (as in the ResNet paper) could be used when the dimensions differ, but it is avoided
            here to keep the module complexity low
        '''
        super().__init__()
        n = max(int(nOut/5), 1)
        n1 = max(nOut - 4*n, 1)
        self.c1 = C(nIn, n, 1, 1)
        self.d1 = CDilated(n, n1, 3, 1, 1)   # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)    # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)    # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)    # dilation rate of 2^3
        self.d16 = CDilated(n, n, 3, 1, 16)  # dilation rate of 2^4
        self.bn = BR(nOut)
        self.add = add

    def forward(self, input):
        '''
        :param input: input feature map
        :return: transformed feature map
        '''
        # reduce
        output1 = self.c1(input)
        # split and transform
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)
        d16 = self.d16(output1)

        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        add4 = add3 + d16

        # merge
        combine = torch.cat([d1, add1, add2, add3, add4], 1)

        # if residual version
        if self.add:
            combine = input + combine
        output = self.bn(combine)
        return output
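

# Channel bookkeeping for the ESP block above (illustrative): with nOut = 64, n = 64 // 5 = 12
# and n1 = 64 - 4 * 12 = 16, so concatenating [d1 (16 ch), add1..add4 (12 ch each)] restores
# exactly 64 channels, which is what allows the identity residual whenever nIn == nOut.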


class InputProjectionA(nn.Module):
    '''
    This class projects the input image to the same spatial dimensions as the feature map.
    For example, with samplingTimes=3 a 512x512x3 input image is down-sampled to 64x64x3,
    matching a feature map at 1/8 resolution.
    '''
    def __init__(self, samplingTimes):
        '''
        :param samplingTimes: the number of times the image is down-sampled by a factor of 2
        '''
        super().__init__()
        self.pool = nn.ModuleList()
        for i in range(0, samplingTimes):
            # pyramid-based approach for down-sampling
            self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, input):
        '''
        :param input: input RGB image
        :return: down-sampled image (pyramid-based approach)
        '''
        for pool in self.pool:
            input = pool(input)
        return input


class ESPNet_Encoder(nn.Module):
    '''
    This class defines the ESPNet-C style encoder, extended with position and channel
    attention on its output feature map
    '''
    def __init__(self, p=5, q=3):
        '''
        :param p: depth multiplier (number of ESP blocks at 1/4 resolution)
        :param q: depth multiplier (number of ESP blocks at 1/8 resolution)
        '''
        super().__init__()
        self.level1 = CBR(3, 16, 3, 2)
        self.sample1 = InputProjectionA(1)
        self.sample2 = InputProjectionA(2)

        self.b1 = CBR(16 + 3, 19, 3)
        self.level2_0 = DownSamplerB(16 + 3, 64)

        self.level2 = nn.ModuleList()
        for i in range(0, p):
            self.level2.append(DilatedParllelResidualBlockB(64, 64))
        self.b2 = CBR(128 + 3, 131, 3)

        self.level3_0 = DownSamplerB(128 + 3, 128)
        self.level3 = nn.ModuleList()
        for i in range(0, q):
            self.level3.append(DilatedParllelResidualBlockB(128, 128))
        # self.mixstyle = MixStyle2(p=0.5, alpha=0.1)
        self.b3 = CBR(256, 32, 3)
        self.sa = PAM_Module(32)
        self.sc = CAM_Module(32)
        self.conv_sa = CBR(32, 32, 3)
        self.conv_sc = CBR(32, 32, 3)
        self.classifier = CBR(32, 32, 1, 1)

    def forward(self, input):
        '''
        :param input: the input RGB image
        :return: the transformed feature map with spatial dimensions 1/8th of the input image
        '''
        output0 = self.level1(input)
        inp1 = self.sample1(input)
        inp2 = self.sample2(input)

        output0_cat = self.b1(torch.cat([output0, inp1], 1))
        output1_0 = self.level2_0(output0_cat)  # down-sampled

        for i, layer in enumerate(self.level2):
            if i == 0:
                output1 = layer(output1_0)
            else:
                output1 = layer(output1)

        output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1))
        output2_0 = self.level3_0(output1_cat)  # down-sampled
        for i, layer in enumerate(self.level3):
            if i == 0:
                output2 = layer(output2_0)
            else:
                output2 = layer(output2)
        cat_ = torch.cat([output2_0, output2], 1)

        output2_cat = self.b3(cat_)
        out_sa = self.sa(output2_cat)
        out_sa = self.conv_sa(out_sa)
        out_sc = self.sc(output2_cat)
        out_sc = self.conv_sc(out_sc)
        out_s = out_sa + out_sc
        classifier = self.classifier(out_s)

        return classifier


class TwinLiteNet(nn.Module):
    '''
    This class defines the TwinLiteNet network: a shared ESPNet-style encoder followed by
    two independent up-sampling decoder heads, each producing a 2-channel output map
    '''
    def __init__(self, p=2, q=3):
        super().__init__()
        self.encoder = ESPNet_Encoder(p, q)

        self.up_1_1 = UPx2(32, 16)
        self.up_2_1 = UPx2(16, 8)

        self.up_1_2 = UPx2(32, 16)
        self.up_2_2 = UPx2(16, 8)

        self.classifier_1 = UPx2(8, 2)
        self.classifier_2 = UPx2(8, 2)

    def forward(self, input):
        x = self.encoder(input)

        x1 = self.up_1_1(x)
        x1 = self.up_2_1(x1)
        classifier1 = self.classifier_1(x1)

        x2 = self.up_1_2(x)
        x2 = self.up_2_2(x2)
        classifier2 = self.classifier_2(x2)

        return (classifier1, classifier2)
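For reference, a minimal usage sketch of the new module (illustrative only; it assumes an input whose height and width are divisible by 8, so that the two decoder heads return maps at the input resolution):

    import torch
    from TwinLite import TwinLiteNet

    model = TwinLiteNet().eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 360, 640)   # H and W divisible by 8
        out1, out2 = model(x)             # two 2-channel heads (in TwinLiteNet these typically
                                          # correspond to drivable-area and lane-line masks)
    print(out1.shape, out2.shape)         # expected: torch.Size([1, 2, 360, 640]) for both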