@@ -19,6 +19,7 @@
     get_symmetric_quantization_config,
 )
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_quantization import TestHelperModules
 from torch.testing._internal.common_utils import TestCase
 
 from torchao import quantize_
@@ -28,9 +29,20 @@
     AffineQuantizedTensor,
     Int4CPULayout,
     Int4XPULayout,
+    PlainLayout,
+    QDQLayout,
+    TensorCoreTiledLayout,
+)
+from torchao.quantization import (
+    LinearActivationQuantizedTensor,
+    PerGroup,
 )
-from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.quant_api import (
+    AOPerModuleConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
+    Int8WeightOnlyConfig,
+    IntxWeightOnlyConfig,
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
@@ -933,6 +945,66 @@ def test_workflow_e2e_numerics(self, config):
         sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_default(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config2 = Int8WeightOnlyConfig()
+        config = AOPerModuleConfig({"_default": config1, "linear2": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear2.weight._layout, PlainLayout)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_module_name(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config2 = Int8WeightOnlyConfig()
+        config = AOPerModuleConfig({"linear1": config1, "linear2": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear2.weight._layout, PlainLayout)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+")
+    def test_ao_per_module_config_embedding_linear(self):
+        weight_dtype = torch.int8
+        granularity = PerGroup(8)
+        mapping_type = MappingType.SYMMETRIC
+        embedding_config = IntxWeightOnlyConfig(
+            weight_dtype=weight_dtype,
+            granularity=granularity,
+            mapping_type=mapping_type,
+            scale_dtype=None,
+        )
+        # example model linear is Linear(16, 8)
+        linear_config = Int8DynamicActivationInt4WeightConfig(group_size=16)
+
+        config = AOPerModuleConfig({"emb": embedding_config, "linear": linear_config})
+        indices = torch.randint(0, 10, (32,))
+        indices = indices.unsqueeze(0)
+        example_inputs = (indices,)
+        model = TestHelperModules.EmbeddingConvLinearModule().eval()
+        model(*example_inputs)
+        quantize_(
+            model,
+            config,
+            filter_fn=lambda x, fqn: isinstance(x, torch.nn.Linear)
+            or isinstance(x, torch.nn.Embedding),
+        )
+        model(*example_inputs)
+
+        assert isinstance(model.emb.weight, AffineQuantizedTensor)
+        assert isinstance(model.emb.weight._layout, QDQLayout)
+        assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
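
For orientation, here is a minimal standalone sketch of the per-module behavior the new tests exercise, using only APIs imported in the diff above (`quantize_`, `AOPerModuleConfig`, `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`). The `TwoLinear` model and the tensor shapes are illustrative assumptions, not part of the test file, and the snippet assumes a CUDA device, mirroring the CUDA-gated tests.

```python
import torch

from torchao import quantize_
from torchao.quantization.quant_api import (
    AOPerModuleConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)


class TwoLinear(torch.nn.Module):
    # Hypothetical stand-in; the tests above use ToyLinearModel instead.
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 128, bias=False)
        self.linear2 = torch.nn.Linear(128, 128, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TwoLinear().cuda().to(torch.bfloat16)

# "_default" applies to every module not matched by name; the "linear2"
# entry overrides it for that submodule only, so linear1 ends up int4
# weight-only and linear2 int8 weight-only, as asserted in the tests.
config = AOPerModuleConfig(
    {
        "_default": Int4WeightOnlyConfig(group_size=32),
        "linear2": Int8WeightOnlyConfig(),
    }
)
quantize_(model, config)
model(torch.randn(8, 128, device="cuda", dtype=torch.bfloat16))
```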