Commit 31d17c0

Add AOPerModuleConfig (#2119)
* Summary:

  This allows:
  - per-module config by name, e.g. linear: config1, embedding: config2
  - skipping config by module name
  - specifying a default config for all supported modules by default (modules that pass the filter function)

  ```
  # "_default" applies config1 to all remaining supported modules
  config1 = Int4WeightOnlyConfig(group_size=32)
  config2 = Int8WeightOnlyConfig()
  config = AOPerModuleConfig({"_default": config1, "linear2": config2})

  # explicit per-module config by name
  config1 = Int4WeightOnlyConfig(group_size=32)
  config2 = Int8WeightOnlyConfig()
  config = AOPerModuleConfig({"linear1": config1, "linear2": config2})
  ```

  Test Plan: python test/quantization/test_quant_api.py -k test_ao_per_module_config

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* ruff
* add embedding test and doc update
* torch version
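The commit message's examples do not show the "skipping" case. Given the `Dict[str, Optional[AOBaseConfig]]` value type and the handler added in this commit, a hedged sketch of skipping a module by name (assuming `model` has `linear1`/`linear2` submodules) might look like this:

```python
# Hedged sketch, not part of the commit: an entry mapped to None resolves to
# no config, so the handler leaves that module unquantized. Note that when a
# "_default" entry is also present, the handler falls back to it for None
# entries, so skipping is only guaranteed without a "_default".
from torchao import quantize_
from torchao.quantization.quant_api import AOPerModuleConfig, Int4WeightOnlyConfig

config = AOPerModuleConfig(
    {
        "linear1": Int4WeightOnlyConfig(group_size=32),
        "linear2": None,  # linear2 is skipped
    }
)
# quantize_(model, config)  # `model` is any nn.Module with linear1/linear2
```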
1 parent 2fcab01 commit 31d17c0

File tree

3 files changed: +168 −5 lines changed


.pre-commit-config.yaml (+1 −1)

```diff
@@ -11,7 +11,7 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.8
+    rev: v0.11.6
     hooks:
       # Run the linter.
       - id: ruff
```

test/quantization/test_quant_api.py (+73 −1)

```diff
@@ -19,6 +19,7 @@
     get_symmetric_quantization_config,
 )
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_quantization import TestHelperModules
 from torch.testing._internal.common_utils import TestCase
 
 from torchao import quantize_
@@ -28,9 +29,20 @@
     AffineQuantizedTensor,
     Int4CPULayout,
     Int4XPULayout,
+    PlainLayout,
+    QDQLayout,
+    TensorCoreTiledLayout,
+)
+from torchao.quantization import (
+    LinearActivationQuantizedTensor,
+    PerGroup,
 )
-from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.quant_api import (
+    AOPerModuleConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
+    Int8WeightOnlyConfig,
+    IntxWeightOnlyConfig,
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
@@ -933,6 +945,66 @@ def test_workflow_e2e_numerics(self, config):
         sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_default(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config2 = Int8WeightOnlyConfig()
+        config = AOPerModuleConfig({"_default": config1, "linear2": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear2.weight._layout, PlainLayout)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_ao_per_module_config_module_name(self):
+        config1 = Int4WeightOnlyConfig(group_size=32)
+        config2 = Int8WeightOnlyConfig()
+        config = AOPerModuleConfig({"linear1": config1, "linear2": config2})
+        model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
+        example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
+        quantize_(model, config)
+        model(*example_inputs)
+        assert isinstance(model.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
+        assert isinstance(model.linear2.weight, AffineQuantizedTensor)
+        assert isinstance(model.linear2.weight._layout, PlainLayout)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+")
+    def test_ao_per_module_config_embedding_linear(self):
+        weight_dtype = torch.int8
+        granularity = PerGroup(8)
+        mapping_type = MappingType.SYMMETRIC
+        embedding_config = IntxWeightOnlyConfig(
+            weight_dtype=weight_dtype,
+            granularity=granularity,
+            mapping_type=mapping_type,
+            scale_dtype=None,
+        )
+        # example model linear is Linear(16, 8)
+        linear_config = Int8DynamicActivationInt4WeightConfig(group_size=16)
+
+        config = AOPerModuleConfig({"emb": embedding_config, "linear": linear_config})
+        indices = torch.randint(0, 10, (32,))
+        indices = indices.unsqueeze(0)
+        example_inputs = (indices,)
+        model = TestHelperModules.EmbeddingConvLinearModule().eval()
+        model(*example_inputs)
+        quantize_(
+            model,
+            config,
+            filter_fn=lambda x, fqn: isinstance(x, torch.nn.Linear)
+            or isinstance(x, torch.nn.Embedding),
+        )
+        model(*example_inputs)
+
+        assert isinstance(model.emb.weight, AffineQuantizedTensor)
+        assert isinstance(model.emb.weight._layout, QDQLayout)
+        assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)
 
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
```
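The new tests use a `ToyLinearModel` helper that is defined elsewhere in this test file and is not part of the diff. Below is a hypothetical minimal sketch of the shape these tests assume (two named linear submodules plus an `example_inputs` method); the real helper may differ:

```python
import torch

class ToyLinearModel(torch.nn.Module):
    """Hypothetical stand-in for the ToyLinearModel helper used above."""

    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        # one activation tensor matching linear1's input width
        return (
            torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.linear2(self.linear1(x))
```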

torchao/quantization/quant_api.py (+94 −3)

```diff
@@ -18,8 +18,8 @@
 import logging
 import types
 import warnings
-from dataclasses import dataclass
-from typing import Any, Callable, Optional, Tuple, Union
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
```
```diff
@@ -307,6 +307,52 @@ def _replace_with_custom_fn_if_matches_filter(
     return model
 
 
+def _replace_with_custom_fn_if_matches_filter_with_name(
+    model,
+    replacement_fn,
+    filter_fn,
+    cur_fqn="",
+    device=None,
+    extra_args: Optional[Tuple[Any, ...]] = (),
+) -> None:
+    """
+    A variant of _replace_with_custom_fn_if_matches_filter where replacement_fn takes module name as well
+    ...
+        replacement_fn (Callable[[torch.nn.Module, str], torch.nn.Module]): The function to replace matching modules.
+    ...
+
+    Returns:
+        None
+    """
+    if isinstance(model, Float8Linear):
+        with torch.device("meta"):
+            new_module = nn.Linear(model.in_features, model.out_features)
+        new_module.weight = model.weight
+        new_module.bias = model.bias
+        model = new_module
+    if filter_fn(model, cur_fqn[:-1]):
+        if device is not None:
+            model.to(device=device)  # move to device before quantization
+        model = replacement_fn(model, cur_fqn[:-1], *extra_args)
+        return model
+    else:
+        named_children_list = list(model.named_children())
+        for name, child in named_children_list:
+            new_child = _replace_with_custom_fn_if_matches_filter_with_name(
+                child,
+                replacement_fn,
+                filter_fn,
+                f"{cur_fqn}{name}.",
+                device,
+                extra_args,
+            )
+            if new_child is not child:
+                setattr(model, name, new_child)
+        if device is not None:
+            model.to(device=device)  # move parent module to device
+        return model
+
+
 def _is_linear(mod, *args):
     # avoid circular dependencies
     from torchao.quantization.qat.affine_fake_quantized_tensor import (
```
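The traversal above threads the fully qualified name through the recursion by appending `"{name}."` at each level and stripping the trailing dot with `cur_fqn[:-1]` before matching. A minimal standalone sketch (not torchao code) of that naming scheme:

```python
import torch.nn as nn

def walk(module: nn.Module, cur_fqn: str = ""):
    # cur_fqn carries a trailing dot; strip it the same way the torchao helper does
    print(f"visiting {cur_fqn[:-1]!r}")
    for name, child in module.named_children():
        walk(child, f"{cur_fqn}{name}.")

m = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
walk(m)
# visiting ''
# visiting '0'
# visiting '1'
# visiting '1.0'
```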
```diff
@@ -547,13 +593,24 @@ def quantize_(
         quantize_(m, int4_weight_only(group_size=32))
 
     """
+    filter_fn = _is_linear if filter_fn is None else filter_fn
+    if isinstance(config, AOPerModuleConfig):
+        _replace_with_custom_fn_if_matches_filter_with_name(
+            model,
+            _ao_per_module_config_handler,
+            filter_fn,
+            device=device,
+            extra_args=(config,),
+        )
+        return
+
     if isinstance(config, AOBaseConfig):
         handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
         # for each linear in the model, apply the transform if filtering passes
         _replace_with_custom_fn_if_matches_filter(
             model,
             handler,
-            _is_linear if filter_fn is None else filter_fn,
+            filter_fn,
             device=device,
             extra_args=(config,),
         )
```
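With this dispatch, per-module quantization goes through the existing `quantize_` entry point. A usage sketch mirroring the commit message and the tests above; `TwoLinear` is a hypothetical stand-in model, and the int4 path shown expects CUDA with bfloat16 weights:

```python
import torch
from torchao import quantize_
from torchao.quantization.quant_api import (
    AOPerModuleConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)

class TwoLinear(torch.nn.Module):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(256, 256, bias=False)
        self.linear2 = torch.nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))

# int4 weight-only for every supported module by default, int8 for linear2
config = AOPerModuleConfig(
    {
        "_default": Int4WeightOnlyConfig(group_size=32),
        "linear2": Int8WeightOnlyConfig(),
    }
)
model = TwoLinear().cuda().to(torch.bfloat16)
quantize_(model, config)
```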
```diff
@@ -1900,6 +1957,40 @@ def _fpx_weight_only_transform(
     return module
 
 
+@dataclass
+class AOPerModuleConfig(AOBaseConfig):
+    """Per module configurations for torchao quantize_ API
+
+    Args:
+        `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from
+         the fully qualified name of a module to the AOBaseConfig that we want to apply to it.
+         Also has a special key "_default": if "_default" is present in the dictionary,
+         its config will be applied to all remaining modules that do not have a
+         per-module configuration specified.
+    """
+
+    module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(
+        default_factory=dict
+    )
+
+
+def _ao_per_module_config_handler(
+    module: torch.nn.Module, module_fqn: str, config: AOPerModuleConfig
+):
+    c = config.module_fqn_to_config.get(module_fqn, None)
+    # Maybe: we can add module-type-specific config in the future, if needed
+    # fall back to the default if no module-specific config is provided
+    default_c = config.module_fqn_to_config.get("_default", None)
+    if default_c is not None and c is None:
+        c = default_c
+
+    if c is not None:
+        handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
+        return handler(module, c)
+
+    return module
+
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals(
         [
```
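The lookup order implemented by `_ao_per_module_config_handler` is: exact FQN entry first, then `"_default"`, then leave the module untouched. A small pure-Python sketch of that resolution (plain strings stand in for configs):

```python
module_fqn_to_config = {"_default": "int4-config", "linear2": "int8-config"}

def resolve(fqn):
    c = module_fqn_to_config.get(fqn, None)
    default_c = module_fqn_to_config.get("_default", None)
    # fall back to the default only when no per-module entry resolved
    if default_c is not None and c is None:
        c = default_c
    return c

assert resolve("linear2") == "int8-config"  # explicit entry wins
assert resolve("linear1") == "int4-config"  # falls back to "_default"
```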

0 commit comments