@@ -100,6 +100,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = nn.Parameter(
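
Initialising `self.kl = 0` in `__init__` means every Flipout layer exposes a `kl` attribute from construction time, before any forward pass has stored a real KL tensor. A minimal sketch of what that guarantees, assuming `Conv1dFlipout` is importable from this module and takes the usual `(in_channels, out_channels, kernel_size)` arguments:

```python
# Hypothetical check, not part of the diff: the attribute exists right away.
layer = Conv1dFlipout(in_channels=3, out_channels=16, kernel_size=5)
print(layer.kl)  # 0 until the first forward pass overwrites it with a KL tensor
```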
@@ -150,7 +152,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)

-    def forward(self, x):
+    def forward(self, x, return_kl=True):

         # linear outputs
         outputs = F.conv1d(x,
@@ -191,8 +193,11 @@ def forward(self, x):
                             dilation=self.dilation,
                             groups=self.groups) * sign_output

+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs


 class Conv2dFlipout(BaseVariationalLayer_):
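
With the new `return_kl` flag the KL term no longer has to be threaded through the call chain: the default keeps the old `(output, kl)` tuple, while `return_kl=False` returns only the output and leaves the KL divergence on `layer.kl`. A usage sketch, assuming the constructor arguments above and a standard `(batch, channels, length)` input:

```python
import torch

# Assumed constructor signature: (in_channels, out_channels, kernel_size).
layer = Conv1dFlipout(in_channels=3, out_channels=16, kernel_size=5)
x = torch.randn(8, 3, 32)

# Default behaviour is unchanged, so existing code keeps working:
out, kl = layer(x)

# New behaviour: single-tensor return, KL read back from the layer afterwards.
out = layer(x, return_kl=False)
kl = layer.kl
```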
@@ -244,6 +249,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -299,7 +306,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)

-    def forward(self, x):
+    def forward(self, x, return_kl=True):

         # linear outputs
         outputs = F.conv2d(x,
@@ -340,8 +347,11 @@ def forward(self, x):
                             dilation=self.dilation,
                             groups=self.groups) * sign_output

+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs


 class Conv3dFlipout(BaseVariationalLayer_):
@@ -388,6 +398,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias

+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -448,7 +460,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)

-    def forward(self, x):
+    def forward(self, x, return_kl=True):

         # linear outputs
         outputs = F.conv3d(x,
@@ -489,8 +501,11 @@ def forward(self, x):
                             dilation=self.dilation,
                             groups=self.groups) * sign_output

+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs


 class ConvTranspose1dFlipout(BaseVariationalLayer_):
@@ -537,6 +552,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias

+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -593,7 +610,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)

-    def forward(self, x):
+    def forward(self, x, return_kl=True):

         # linear outputs
         outputs = F.conv_transpose1d(x,
@@ -635,8 +652,11 @@ def forward(self, x):
                             dilation=self.dilation,
                             groups=self.groups) * sign_output

+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs


 class ConvTranspose2dFlipout(BaseVariationalLayer_):
@@ -683,6 +703,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias

+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -743,7 +765,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)

-    def forward(self, x):
+    def forward(self, x, return_kl=True):

         # linear outputs
         outputs = F.conv_transpose2d(x,
@@ -785,8 +807,11 @@ def forward(self, x):
                             dilation=self.dilation,
                             groups=self.groups) * sign_output

+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs


 class ConvTranspose3dFlipout(BaseVariationalLayer_):
@@ -838,6 +863,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -893,7 +920,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)

-    def forward(self, x):
+    def forward(self, x, return_kl=True):

         # linear outputs
         outputs = F.conv_transpose3d(x,
@@ -935,5 +962,8 @@ def forward(self, x):
                             dilation=self.dilation,
                             groups=self.groups) * sign_output

+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
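
Because each layer now stashes its KL term in `self.kl` after every forward pass, a training loop can assemble the total KL for the ELBO by walking the model's modules instead of unpacking a tuple from every Bayesian layer. A sketch of that pattern; `collect_kl` is an illustrative helper, not something introduced by this diff:

```python
import torch
import torch.nn as nn


def collect_kl(model: nn.Module):
    """Sum the KL terms stashed on sub-modules that expose a .kl tensor.

    Illustrative helper (not part of the diff). It assumes the Bayesian layers
    have already run a forward pass, so .kl holds a tensor rather than the
    initial 0 set in __init__.
    """
    kl = 0.0
    for module in model.modules():
        if torch.is_tensor(getattr(module, "kl", None)):
            kl = kl + module.kl
    return kl


# Typical training step, assuming the model calls its Flipout layers with
# return_kl=False so plain tensors flow through the network:
#   output = model(x)
#   loss = criterion(output, target) + collect_kl(model) / batch_size
```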