
Commit daa7292

Merge pull request #9 from piEsposito/main
Let the models return prediction only, saving KL Divergence as an attribute
2 parents 7abcfe7 + f1fc4e5 commit daa7292

File tree

6 files changed: +121 -32 lines changed

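The change is backward compatible: forward still returns (prediction, kl) by default, and passing return_kl=False opts into a prediction-only return while the KL term is cached on the layer's new kl attribute. A minimal before/after sketch, assuming LinearFlipout is importable from the linear_flipout.py path shown in this diff, with illustrative shapes:

import torch
from bayesian_torch.layers.flipout_layers.linear_flipout import LinearFlipout

layer = LinearFlipout(in_features=16, out_features=4)
x = torch.randn(8, 16)

# Old (and still default) behavior: forward returns (prediction, kl).
out, kl = layer(x)

# New opt-out path: prediction only...
out = layer(x, return_kl=False)
# ...with the KL divergence cached as an attribute during forward.
kl = layer.kl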

bayesian_torch/layers/flipout_layers/conv_flipout.py (+42 -12)
@@ -100,6 +100,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = nn.Parameter(
@@ -150,7 +152,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv1d(x,
@@ -191,8 +193,11 @@ def forward(self, x):
                                     dilation=self.dilation,
                                     groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class Conv2dFlipout(BaseVariationalLayer_):
@@ -244,6 +249,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -299,7 +306,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv2d(x,
@@ -340,8 +347,11 @@ def forward(self, x):
                                     dilation=self.dilation,
                                     groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class Conv3dFlipout(BaseVariationalLayer_):
@@ -388,6 +398,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias
 
+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -448,7 +460,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv3d(x,
@@ -489,8 +501,11 @@ def forward(self, x):
                                     dilation=self.dilation,
                                     groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class ConvTranspose1dFlipout(BaseVariationalLayer_):
@@ -537,6 +552,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias
 
+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -593,7 +610,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv_transpose1d(x,
@@ -635,8 +652,11 @@ def forward(self, x):
                                     dilation=self.dilation,
                                     groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class ConvTranspose2dFlipout(BaseVariationalLayer_):
@@ -683,6 +703,8 @@ def __init__(self,
         self.groups = groups
         self.bias = bias
 
+        self.kl = 0
+
         self.prior_mean = prior_mean
         self.prior_variance = prior_variance
         self.posterior_mu_init = posterior_mu_init
@@ -743,7 +765,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv_transpose2d(x,
@@ -785,8 +807,11 @@ def forward(self, x):
                                     dilation=self.dilation,
                                     groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
 
 
 class ConvTranspose3dFlipout(BaseVariationalLayer_):
@@ -838,6 +863,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
 
+        self.kl = 0
+
         self.mu_kernel = nn.Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -893,7 +920,7 @@ def init_parameters(self):
         self.prior_bias_mu.data.fill_(self.prior_mean)
         self.prior_bias_sigma.data.fill_(self.prior_variance)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
 
         # linear outputs
         outputs = F.conv_transpose3d(x,
@@ -935,5 +962,8 @@ def forward(self, x):
                                     dilation=self.dilation,
                                     groups=self.groups) * sign_output
 
+        self.kl = kl
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
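The same return_kl/self.kl pattern is applied to all six conv and transposed-conv Flipout classes in this file. A short sketch with Conv2dFlipout (constructor arguments and shapes are illustrative):

import torch
from bayesian_torch.layers.flipout_layers.conv_flipout import Conv2dFlipout

conv = Conv2dFlipout(in_channels=3, out_channels=8, kernel_size=3)
x = torch.randn(2, 3, 32, 32)  # (batch, channels, height, width)

out = conv(x, return_kl=False)  # prediction only
kl = conv.kl                    # KL term cached by the forward pass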

bayesian_torch/layers/flipout_layers/linear_flipout.py (+8 -2)
@@ -90,6 +90,8 @@ def __init__(self,
                              torch.Tensor(out_features, in_features),
                              persistent=False)
 
+        self.kl = 0
+
         if bias:
             self.mu_bias = nn.Parameter(torch.Tensor(out_features))
             self.rho_bias = nn.Parameter(torch.Tensor(out_features))
@@ -123,7 +125,7 @@ def init_parameters(self):
            self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
            self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)
 
-    def forward(self, x):
+    def forward(self, x, return_kl=True):
         # sampling delta_W
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
         delta_weight = (sigma_weight * self.eps_weight.data.normal_())
@@ -148,5 +150,9 @@ def forward(self, x):
         perturbed_outputs = F.linear(x * sign_input, delta_weight,
                                      bias) * sign_output
 
+        self.kl = kl
+
         # returning outputs + perturbations
-        return outputs + perturbed_outputs, kl
+        if return_kl:
+            return outputs + perturbed_outputs, kl
+        return outputs + perturbed_outputs
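Caching KL on each layer is what lets a wrapper model return predictions only and recover the ELBO's KL term afterwards by walking its submodules. A sketch of that pattern; sum_module_kl and BayesMLP below are hypothetical, not part of this PR:

import torch
import torch.nn as nn
from bayesian_torch.layers.flipout_layers.linear_flipout import LinearFlipout

def sum_module_kl(model):
    # Hypothetical helper: each Flipout layer initializes self.kl = 0 and
    # overwrites it with a tensor on forward, so the cached terms can be summed.
    return sum(m.kl for m in model.modules() if hasattr(m, "kl"))

class BayesMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = LinearFlipout(in_features=16, out_features=32)
        self.fc2 = LinearFlipout(in_features=32, out_features=4)

    def forward(self, x):
        # Prediction-only calls keep this module's own return signature simple.
        x = torch.relu(self.fc1(x, return_kl=False))
        return self.fc2(x, return_kl=False)

model = BayesMLP()
out = model(torch.randn(8, 16))  # prediction only
kl_total = sum_module_kl(model)  # KL term for the ELBO loss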

bayesian_torch/layers/flipout_layers/rnn_flipout.py (+7 -2)
@@ -76,6 +76,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho))
         self.bias = bias
 
+        self.kl = 0
+
         self.ih = LinearFlipout(prior_mean=prior_mean,
                                 prior_variance=prior_variance,
                                 posterior_mu_init=posterior_mu_init,
@@ -92,7 +94,7 @@ def __init__(self,
                                 out_features=out_features * 4,
                                 bias=bias)
 
-    def forward(self, X, hidden_states=None):
+    def forward(self, X, hidden_states=None, return_kl=True):
 
         batch_size, seq_size, _ = X.size()
 
@@ -137,4 +139,7 @@ def forward(self, X, hidden_states=None):
         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
         c_ts = c_ts.transpose(0, 1).contiguous()
 
-        return hidden_seq, (hidden_seq, c_ts), kl
+        self.kl = kl
+        if return_kl:
+            return hidden_seq, (hidden_seq, c_ts), kl
+        return hidden_seq, (hidden_seq, c_ts)
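For the recurrent layer, return_kl=False drops only the trailing kl from the return tuple, leaving hidden_seq, (hidden_seq, c_ts). A sketch, assuming the LSTM class in this file is named LSTMFlipout and takes in_features/out_features like the LinearFlipout sublayers it builds; the input follows the (batch, seq, features) layout read by X.size() in forward:

import torch
from bayesian_torch.layers.flipout_layers.rnn_flipout import LSTMFlipout

rnn = LSTMFlipout(in_features=10, out_features=20)
x = torch.randn(4, 7, 10)  # (batch, seq, features)

hidden_seq, (h, c) = rnn(x, return_kl=False)  # prediction only
kl = rnn.kl                                   # cached during forward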
