
Commit 342ca39

committed Oct 25, 2023
Add quant prepare functions
1 parent 39b41a5 commit 342ca39

2 files changed: +291 −39 lines

 

bayesian_torch/layers/flipout_layers/conv_flipout.py

+166 −34
@@ -210,13 +210,33 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv1d(x * sign_input,
-                                     bias=bias,
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv1d(x * sign_input,
                                      weight=delta_kernel,
+                                     bias=bias,
                                      stride=self.stride,
                                      padding=self.padding,
                                      dilation=self.dilation,
-                                     groups=self.groups) * sign_output
+                                     groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -513,6 +533,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # prior values
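
For context, a minimal sketch of how the new per-layer prepare() hook might be exercised; the layer choice, shapes, and input below are illustrative assumptions, not part of this commit:

    import torch
    from bayesian_torch.layers import Conv1dFlipout

    layer = Conv1dFlipout(in_channels=8, out_channels=16, kernel_size=3)  # illustrative sizes
    layer.prepare()   # builds the qint_quant / quint_quant stub lists and sets quant_prepare=True
    layer.eval()

    x = torch.randn(4, 8, 32)  # (batch, channels, length) calibration-style input
    out, kl = layer(x)         # forward now routes tensors through the QuantStubs
                               # (identity until observers are attached by torch.quantization.prepare)
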
@@ -575,13 +604,33 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv3d(x * sign_input,
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv3d(x * sign_input,
                                      weight=delta_kernel,
                                      bias=bias,
                                      stride=self.stride,
                                      padding=self.padding,
                                      dilation=self.dilation,
-                                     groups=self.groups) * sign_output
+                                     groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -677,12 +726,20 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # prior values
         self.prior_weight_mu.data.fill_(self.prior_mean)
-        self.prior_weight_sigma.data.fill_
-        (self.prior_variance)
+        self.prior_weight_sigma.data.fill_(self.prior_variance)
 
         # init our weights for the deterministic and perturbated weights
         self.mu_kernel.data.normal_(mean=self.posterior_mu_init, std=.1)
@@ -741,15 +798,34 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv_transpose1d(
-            x * sign_input,
-            weight=delta_kernel,
-            bias=bias,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            dilation=self.dilation,
-            groups=self.groups) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv_transpose1d(x * sign_input,
+                                                   weight=delta_kernel,
+                                                   bias=bias,
+                                                   stride=self.stride,
+                                                   padding=self.padding,
+                                                   output_padding=self.output_padding,
+                                                   dilation=self.dilation,
+                                                   groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -850,6 +926,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # prior values
@@ -913,15 +998,34 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv_transpose2d(
-            x * sign_input,
-            bias=bias,
-            weight=delta_kernel,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            dilation=self.dilation,
-            groups=self.groups) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv_transpose2d(x * sign_input,
+                                                   weight=delta_kernel,
+                                                   bias=bias,
+                                                   stride=self.stride,
+                                                   padding=self.padding,
+                                                   output_padding=self.output_padding,
+                                                   dilation=self.dilation,
+                                                   groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
@@ -1022,6 +1126,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(4)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(8)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         # prior values
@@ -1084,15 +1197,34 @@ def forward(self, x, return_kl=True):
                                   self.prior_bias_sigma)
 
         # perturbed feedforward
-        perturbed_outputs = F.conv_transpose3d(
-            x * sign_input,
-            weight=delta_kernel,
-            bias=bias,
-            stride=self.stride,
-            padding=self.padding,
-            output_padding=self.output_padding,
-            dilation=self.dilation,
-            groups=self.groups) * sign_output
+        x_tmp = x * sign_input
+        perturbed_outputs_tmp = F.conv_transpose3d(x * sign_input,
+                                                   weight=delta_kernel,
+                                                   bias=bias,
+                                                   stride=self.stride,
+                                                   padding=self.padding,
+                                                   output_padding=self.output_padding,
+                                                   dilation=self.dilation,
+                                                   groups=self.groups)
+        perturbed_outputs = perturbed_outputs_tmp * sign_output
+        out = outputs + perturbed_outputs
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            x = self.quint_quant[0](x) # input
+            outputs = self.quint_quant[1](outputs) # output
+            sign_input = self.quint_quant[2](sign_input)
+            sign_output = self.quint_quant[3](sign_output)
+            x_tmp = self.quint_quant[4](x_tmp)
+            perturbed_outputs_tmp = self.quint_quant[5](perturbed_outputs_tmp) # output
+            perturbed_outputs = self.quint_quant[6](perturbed_outputs) # output
+            out = self.quint_quant[7](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            delta_kernel =self.qint_quant[3](delta_kernel) # multiply activation
 
         self.kl = kl
         # returning outputs + perturbations
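
The same prepare()/quant_prepare pattern recurs across the flipout convolution variants above, so a model-level pass can switch all of them on at once. A sketch of a hypothetical helper (the function name and traversal are assumptions, not part of this commit):

    import torch.nn as nn

    def prepare_bayesian_layers(model: nn.Module) -> nn.Module:
        # Hypothetical helper: walk the module tree and call prepare() on every
        # submodule that defines it (the Bayesian conv layers touched by this commit).
        for module in model.modules():
            if hasattr(module, "prepare") and callable(module.prepare):
                module.prepare()
        return model
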

bayesian_torch/layers/variational_layers/conv_variational.py

+125 −5
@@ -147,6 +147,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         self.prior_weight_mu.data.fill_(self.prior_mean)
@@ -177,7 +186,9 @@ def forward(self, input, return_kl=True):
 
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -193,6 +204,19 @@ def forward(self, input, return_kl=True):
 
         out = F.conv1d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            out = self.quint_quant[1](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            tmp_result =self.qint_quant[3](tmp_result) # multiply activation
+            weight = self.qint_quant[4](weight) # add activatation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -470,6 +494,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -500,7 +533,9 @@ def forward(self, input, return_kl=True):
 
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -516,6 +551,19 @@ def forward(self, input, return_kl=True):
 
         out = F.conv3d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            out = self.quint_quant[1](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            tmp_result =self.qint_quant[3](tmp_result) # multiply activation
+            weight = self.qint_quant[4](weight) # add activatation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -614,6 +662,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -644,7 +701,9 @@ def forward(self, input, return_kl=True):
 
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -661,6 +720,19 @@ def forward(self, input, return_kl=True):
         out = F.conv_transpose1d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            out = self.quint_quant[1](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            tmp_result =self.qint_quant[3](tmp_result) # multiply activation
+            weight = self.qint_quant[4](weight) # add activatation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -765,6 +837,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -795,7 +876,9 @@ def forward(self, input, return_kl=True):
 
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -812,6 +895,19 @@ def forward(self, input, return_kl=True):
         out = F.conv_transpose2d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            out = self.quint_quant[1](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            tmp_result =self.qint_quant[3](tmp_result) # multiply activation
+            weight = self.qint_quant[4](weight) # add activatation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -917,6 +1013,15 @@ def __init__(self,
             self.register_buffer('prior_bias_sigma', None, persistent=False)
 
         self.init_parameters()
+        self.quant_prepare=False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8,qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare=True
 
     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -947,7 +1052,9 @@ def forward(self, input, return_kl=True):
 
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
        eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -964,6 +1071,19 @@ def forward(self, input, return_kl=True):
         out = F.conv_transpose3d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input) # input
+            out = self.quint_quant[1](out) # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight) # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel) # weight
+            eps_kernel = self.qint_quant[2](eps_kernel) # random variable
+            tmp_result =self.qint_quant[3](tmp_result) # multiply activation
+            weight = self.qint_quant[4](weight) # add activatation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
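
Both files gate the observer calls behind quant_prepare so the layers can be calibrated with the stock eager-mode post-training quantization flow. A rough sketch, assuming a trained model and a representative calibration_loader (both placeholders; only prepare()/quant_prepare come from this commit):

    import torch

    model.eval()                       # `model` is a placeholder trained network
    for m in model.modules():          # enable the stubs added in this commit
        if hasattr(m, "prepare"):
            m.prepare()

    # Attach observers to each QuantStub according to its QConfig.
    torch.quantization.prepare(model, inplace=True)

    with torch.no_grad():              # calibration: let the MinMaxObservers record ranges
        for x, _ in calibration_loader:    # `calibration_loader` is a placeholder
            model(x)

    # Standard convert step; how the recorded ranges are consumed by quantized
    # Bayesian layers is outside this diff.
    torch.quantization.convert(model, inplace=True)
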
