@@ -147,6 +147,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)

         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True

     def init_parameters(self):
         self.prior_weight_mu.data.fill_(self.prior_mean)
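Each QuantStub created in prepare() carries an explicit QConfig whose MinMaxObserver records the running min/max of whatever tensor is routed through that stub during calibration, from which per-tensor quantization parameters are derived: a symmetric qint8 scheme for weight-like tensors and quint8 for activations. A minimal standalone sketch of that observer behaviour, with an arbitrary placeholder tensor shape not taken from this commit:

import torch
from torch.quantization import MinMaxObserver

# An observer like the ones attached via prepare(): it tracks the min/max of
# every tensor passed through it and converts that range into per-tensor
# scale/zero-point for the configured dtype and qscheme.
obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
obs(torch.randn(8, 16))                      # record ranges from a calibration tensor
scale, zero_point = obs.calculate_qparams()  # quantization parameters from the observed range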
@@ -177,7 +186,9 @@ def forward(self, input, return_kl=True):

         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -193,6 +204,19 @@ def forward(self, input, return_kl=True):

         out = F.conv1d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel)  # weight
+            eps_kernel = self.qint_quant[2](eps_kernel)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
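Splitting the sampled weight into an explicit multiply (tmp_result = sigma_weight * eps_kernel) and add lets each intermediate tensor pass through its own stub, so observers can collect ranges for the noise, the scaled noise, and the resulting weight separately. A rough sketch of how a layer exposing this prepare() hook might be calibrated with the eager-mode workflow; the import path, class name, tensor sizes, and calibration loop below are assumptions for illustration, not part of this diff:

import torch
from bayesian_torch.layers import Conv1dReparameterization  # assumed import path

# hypothetical layer instance with the prepare() hook added in this commit
layer = Conv1dReparameterization(in_channels=16, out_channels=32, kernel_size=3)
layer.prepare()                                    # attach QuantStubs / DeQuantStub
layer.eval()

prepared = torch.quantization.prepare(layer)       # add observers to the stubs
with torch.no_grad():
    for _ in range(10):                            # placeholder calibration data
        x = torch.randn(4, 16, 64)
        prepared(x, return_kl=False)               # observers record tensor ranges

quantized = torch.quantization.convert(prepared)   # swap stubs for quantize ops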
@@ -470,6 +494,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)

         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True

     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -500,7 +533,9 @@ def forward(self, input, return_kl=True):

         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -516,6 +551,19 @@ def forward(self, input, return_kl=True):

         out = F.conv3d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel)  # weight
+            eps_kernel = self.qint_quant[2](eps_kernel)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -614,6 +662,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)

         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True

     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -644,7 +701,9 @@ def forward(self, input, return_kl=True):

         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -661,6 +720,19 @@ def forward(self, input, return_kl=True):
         out = F.conv_transpose1d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel)  # weight
+            eps_kernel = self.qint_quant[2](eps_kernel)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -765,6 +837,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)

         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True

     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -795,7 +876,9 @@ def forward(self, input, return_kl=True):

         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -812,6 +895,19 @@ def forward(self, input, return_kl=True):
         out = F.conv_transpose2d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel)  # weight
+            eps_kernel = self.qint_quant[2](eps_kernel)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias
@@ -917,6 +1013,15 @@ def __init__(self,
         self.register_buffer('prior_bias_sigma', None, persistent=False)

         self.init_parameters()
+        self.quant_prepare = False
+
+    def prepare(self):
+        self.qint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric), activation=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))) for _ in range(5)])
+        self.quint_quant = nn.ModuleList([torch.quantization.QuantStub(
+            QConfig(weight=MinMaxObserver.with_args(dtype=torch.quint8), activation=MinMaxObserver.with_args(dtype=torch.quint8))) for _ in range(2)])
+        self.dequant = torch.quantization.DeQuantStub()
+        self.quant_prepare = True

     def init_parameters(self):
         self.prior_weight_mu.fill_(self.prior_mean)
@@ -947,7 +1052,9 @@ def forward(self, input, return_kl=True):

         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
-        weight = self.mu_kernel + (sigma_weight * eps_kernel)
+        tmp_result = sigma_weight * eps_kernel
+        weight = self.mu_kernel + tmp_result
+
         if return_kl:
             kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
@@ -964,6 +1071,19 @@ def forward(self, input, return_kl=True):
         out = F.conv_transpose3d(input, weight, bias, self.stride,
                                  self.padding, self.output_padding,
                                  self.dilation, self.groups)
+
+        if self.quant_prepare:
+            # quint8 quantstub
+            input = self.quint_quant[0](input)  # input
+            out = self.quint_quant[1](out)  # output
+
+            # qint8 quantstub
+            sigma_weight = self.qint_quant[0](sigma_weight)  # weight
+            mu_kernel = self.qint_quant[1](self.mu_kernel)  # weight
+            eps_kernel = self.qint_quant[2](eps_kernel)  # random variable
+            tmp_result = self.qint_quant[3](tmp_result)  # multiply activation
+            weight = self.qint_quant[4](weight)  # add activation
+
         if return_kl:
             if self.bias:
                 kl = kl_weight + kl_bias