Commit 4805efd

create staticmethod for quantizing weights of QATLinear and QATEmbedding
Differential Revision: D73201409
Pull Request resolved: #2079
1 parent 7b05105

File tree

2 files changed (+64, -39 lines)


torchao/quantization/qat/embedding.py (+30, -17)
```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -196,15 +196,40 @@ def convert(
         """
         self._convert_helper(model)
         return model
+
+    @staticmethod
+    def quantize_weights(
+        weight: torch.Tensor,
+        bit_width: int,
+        group_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function to quantize weights
+        """
+        (qmin, qmax) = _get_qmin_qmax(bit_width)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, bit_width, group_size
+        )
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        return (q_weight, s, zp)
+
 
     def _convert_helper(self, module: torch.nn.Module):
         """
         Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
         modules with `Int4WeightOnlyEmbedding`
         """
-        from torchao._executorch_ops import (
-            _quantized_decomposed_quantize_per_channel_group_wrapper,
-        )
 
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
@@ -230,20 +255,8 @@ def _convert_helper(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_embedding)
 
+                q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
                 # Load weights and qparams into quantized embedding
-                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
-                (s, zp) = get_group_qparams_symmetric(
-                    child.weight, self.bit_width, group_size
-                )
-                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight,
-                    s,
-                    zp,
-                    qmin,
-                    qmax,
-                    torch.int8,
-                    group_size,
-                )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scale = s.to(scale_precision)
                 quantized_embedding.zero_point = zp.to(zero_point_precision)
```
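
With `quantize_weights` exposed as a staticmethod, an embedding weight tensor can be quantized without constructing the quantizer or running the full module swap. A minimal sketch of such a call follows; the quantizer class name (`Int4WeightOnlyEmbeddingQATQuantizer`) and the tensor shapes are assumptions for illustration, not taken from this diff.

```python
# Hedged sketch: assumes the staticmethod is reachable from the embedding QAT
# quantizer class (named Int4WeightOnlyEmbeddingQATQuantizer here); verify the
# exact class name in torchao/quantization/qat/embedding.py.
import torch

from torchao.quantization.qat.embedding import Int4WeightOnlyEmbeddingQATQuantizer

weight = torch.randn(128, 256)  # example embedding table: 128 rows, dim 256

# No quantizer instance needed: the helper depends only on the tensor,
# the bit width, and the group size.
q_weight, scales, zero_points = Int4WeightOnlyEmbeddingQATQuantizer.quantize_weights(
    weight,
    4,   # bit_width
    32,  # group_size (embedding dim must be divisible by this)
)
print(q_weight.dtype)  # torch.int8 container holding the 4-bit values
```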

torchao/quantization/qat/linear.py (+34, -22)
```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -197,6 +197,36 @@ def convert(
     ) -> torch.nn.Module:
         self._convert_qat_linear_8da4w(model)
         return model
+
+    @staticmethod
+    def quantize_weights(
+        weight: torch.Tensor,
+        group_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Helper function to quantize weights
+        """
+        # Load weights and qparams into quantized linear
+        n_bit = 4
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, n_bit, group_size
+        )
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
+
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        return (q_weight, s, zp)
+
 
     def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
         """
@@ -215,28 +245,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_linear)
 
-                # Load weights and qparams into quantized linear
-                n_bit = 4
-                (qmin, qmax) = _get_qmin_qmax(n_bit)
-                (s, zp) = get_group_qparams_symmetric(
-                    child.weight, n_bit, config.group_size
-                )
-                from torchao._executorch_ops import (
-                    _quantized_decomposed_quantize_per_channel_group_wrapper,
-                )
-
-                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-                    child.weight,
-                    s,
-                    zp,
-                    qmin,
-                    qmax,
-                    torch.int8,
-                    config.group_size,
-                )
+                q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size)
                 quantized_linear.weight = q_weight
-                quantized_linear.scales = s
-                quantized_linear.zeros = zp
+                quantized_linear.scales = scales
+                quantized_linear.zeros = zeros
                 if child.bias is not None:
                     quantized_linear.bias = child.bias
                 else:
```
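
The same pattern applies on the linear side, where the helper fixes the weight bit width at 4 (`n_bit = 4`) and only takes the group size. A minimal sketch follows; the quantizer class name is an assumption inferred from the `_convert_qat_linear_8da4w` method name, not stated in this diff.

```python
# Hedged sketch: Int8DynActInt4WeightQATQuantizer is assumed to be the 8da4w QAT
# quantizer class; verify the exact name in torchao/quantization/qat/linear.py.
import torch

from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

weight = torch.randn(64, 256)  # example linear weight: out_features=64, in_features=256

# Only the group size is passed; the helper hard-codes n_bit = 4 internally.
q_weight, scales, zeros = Int8DynActInt4WeightQATQuantizer.quantize_weights(weight, 32)
```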
