From 1d6f7d4614cd7e22af44361da30d2ffa6e8ab4b1 Mon Sep 17 00:00:00 2001
From: Naveen Suda
Date: Mon, 21 Apr 2025 12:59:12 -0700
Subject: [PATCH] compare prepared vs. converted outputs for Embedding (#2087)

Summary:
Fixed the embedding op and updated the test. Previously, the converted
quantized embedding dequantized its weight to `x.dtype`, but `x` holds
integer token indices, so the dequantized weight took the indices' dtype
instead of the original weight precision. The original weight dtype is
now recorded at convert time and used for dequantization in forward().
The test now runs the same inputs through the prepared and converted
models and asserts that their outputs match exactly (infinite SQNR).

Reviewed By: telgamal-1, jerryzh168

Differential Revision: D73266106
---
 test/quantization/test_qat.py         | 11 ++++++++---
 torchao/quantization/qat/embedding.py |  5 ++++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 2bc58ffdbe..76d253566a 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -808,6 +808,8 @@ def test_qat_4w_embedding(self):
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
         from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.utils import compute_error
+
         group_size = 256
         model = M2()
@@ -816,9 +818,9 @@ def test_qat_4w_embedding(self):
         quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size)
         prepared = quantizer.prepare(model)
         prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight)
-        prepared(*x)
-        converted = quantizer.convert(model)
-        converted(*x)
+        prepared_output = prepared(*x)
+        converted = quantizer.convert(copy.deepcopy(prepared))
+        converted_output = converted(*x)
 
         # Assert the scales, zero points, and weights are correct after convert
         qmin, qmax = -8, 7
@@ -837,9 +839,12 @@ def test_qat_4w_embedding(self):
             torch.int8,
             group_size,
         )
+        sqnr = compute_error(prepared_output, converted_output).detach().item()
         torch.testing.assert_close(converted.embedding.weight, q_weight)
         torch.testing.assert_close(converted.embedding.scale, s)
         torch.testing.assert_close(converted.embedding.zero_point, zp)
+        torch.testing.assert_close(sqnr, float('inf'))
+
 
     def test_fake_quantize_config_granularity(self):
         """
diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py
index 02772f05f0..34ff48a953 100644
--- a/torchao/quantization/qat/embedding.py
+++ b/torchao/quantization/qat/embedding.py
@@ -251,6 +251,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 group_size=group_size,
                 scale_precision=scale_precision,
                 zero_point_precision=zero_point_precision,
+                weight_original_precision=child.weight.dtype,
                 device=child.weight.device,
             )
             setattr(module, name, quantized_embedding)
@@ -336,6 +337,7 @@ def __init__(
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
+        weight_original_precision: torch.dtype = torch.float32,
         device: torch.device = None,
     ):
         super().__init__()
@@ -354,6 +356,7 @@ def __init__(
         self.group_size = group_size
         self.scale_precision = scale_precision
         self.zero_point_precision = zero_point_precision
+        self.weight_original_precision = weight_original_precision
 
         # currently storing unpacked int8 weights
         self.register_buffer(
@@ -393,7 +396,7 @@ def forward(self, x):
             qmax,
             torch.int8,
             self.group_size,
-            x.dtype,
+            self.weight_original_precision,
         )
         return F.embedding(
             x,
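
Note: below is a minimal standalone sketch of the check this patch adds,
for readers who want to reproduce it outside the test suite. The toy
model, tensor shapes, and group size are illustrative assumptions, not
part of the patch; only Int4WeightOnlyEmbeddingQATQuantizer and
compute_error come from the torchao modules touched here.

    import copy

    import torch
    import torch.nn as nn

    from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
    from torchao.quantization.utils import compute_error

    # Hypothetical toy model with a single embedding layer; group_size
    # must divide the embedding dim (256 here, one group per row).
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(4096, 256)

        def forward(self, x):
            return self.embedding(x)

    model = ToyModel()
    x = torch.randint(0, 4096, (1, 32))  # integer token indices

    # prepare() swaps in fake-quantized QAT embeddings.
    quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=256)
    prepared = quantizer.prepare(model)
    prepared_out = prepared(x)

    # Convert a deep copy so `prepared` stays intact for the comparison,
    # mirroring the updated test.
    converted = quantizer.convert(copy.deepcopy(prepared))
    converted_out = converted(x)

    # With the dtype fix, the fake-quantized and truly quantized paths
    # should produce identical outputs, i.e. infinite SQNR. Before the
    # fix, the converted path dequantized to x.dtype (integer indices).
    sqnr = compute_error(prepared_out, converted_out).item()
    print(sqnr)  # expected: inf

Converting a deep copy rather than the prepared model itself is what
makes the comparison possible at all: prepare/convert mutate the module
in place, so without the copy the prepared model would no longer exist
to run against the converted one.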