# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Any, Optional
+from typing import Any, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -196,15 +196,40 @@ def convert(
        """
        self._convert_helper(model)
        return model
+
+    @staticmethod
+    def quantize_weights(
+        weight: torch.Tensor,
+        bit_width: int,
+        group_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function to quantize weights
+        """
+        (qmin, qmax) = _get_qmin_qmax(bit_width)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, bit_width, group_size
+        )
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        return (q_weight, s, zp)
+

    def _convert_helper(self, module: torch.nn.Module):
        """
        Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
        modules with `Int4WeightOnlyEmbedding`
        """
-        from torchao._executorch_ops import (
-            _quantized_decomposed_quantize_per_channel_group_wrapper,
-        )

        for name, child in module.named_children():
            if isinstance(child, Int4WeightOnlyQATEmbedding):
@@ -230,20 +255,8 @@ def _convert_helper(self, module: torch.nn.Module):
                )
                setattr(module, name, quantized_embedding)

+                q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
                # Load weights and qparams into quantized embedding
-                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(
-                    child.weight, self.bit_width, group_size
-                )
-                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight,
-                    s,
-                    zp,
-                    qmin,
-                    qmax,
-                    torch.int8,
-                    group_size,
-                )
                quantized_embedding.weight = q_weight
                quantized_embedding.scale = s.to(scale_precision)
                quantized_embedding.zero_point = zp.to(zero_point_precision)
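Since the new `quantize_weights` helper is a `@staticmethod`, it can also be called directly, without going through `convert`/`_convert_helper`. A minimal usage sketch, assuming the enclosing quantizer class (not named in this diff) is available as `EmbeddingQATQuantizer`, with illustrative tensor shape, bit-width, and group-size values:

import torch

# `EmbeddingQATQuantizer` is a hypothetical stand-in for the class that
# gains the `quantize_weights` staticmethod in this commit.
weight = torch.randn(4096, 256)  # illustrative embedding table (rows x dim)

# Computes per-group symmetric scales/zero points over each `group_size`
# block of a row, then quantizes the weight to int8 storage clamped to
# the 4-bit qmin/qmax range.
q_weight, s, zp = EmbeddingQATQuantizer.quantize_weights(
    weight, bit_width=4, group_size=32
)

Hoisting this block out of `_convert_helper` moves the inline quantization logic into a reusable helper, which is also why the `torchao._executorch_ops` import moves from `_convert_helper` into `quantize_weights` itself.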