Commit 0844de3

Add developer guide code to tutorials (#588)
Summary: Moved the notebook https://colab.research.google.com/drive/1jqC53MwiW9dSiPS-a6hng_yo0ywdc3nH#scrollTo=Aj9ii4darSRA from #391 into the `tutorials` folder so that the code can be executed while we develop new APIs/utils and can be kept up to date.

Test Plan:
python tutorials/developer_api_guide.py

regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 08024c6 commit 0844de3

8 files changed: +483, -92 lines changed
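Before the per-file diffs, a rough sketch (not part of the diff itself) of the dispatch convention this commit standardizes: op overrides registered via `implements` now receive `(func, types, args, kwargs)` with the positional and keyword arguments packed, instead of being splatted as `*args, **kwargs`. `MyTensor` below is a hypothetical subclass for illustration, and the helper names are assumed importable from `torchao.dtypes.utils` as suggested by the diffs that follow.

```python
# Illustrative sketch only -- `MyTensor` is hypothetical; the helper names come
# from the torchao/dtypes/utils.py and affine_quantized_tensor.py diffs below.
import torch
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)

aten = torch.ops.aten

class MyTensor(torch.Tensor):
    # wire the shared table-based dispatch onto the subclass
    implements = classmethod(_implements)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

@MyTensor.implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    # new convention: `args` is the packed positional tuple, `kwargs` the dict
    input_tensor, weight_tensor = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    # a real implementation would dequantize or call a specialized kernel here
    ...
```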

torchao/dtypes/affine_quantized_tensor.py (+79, -68)
@@ -22,71 +22,17 @@
     _register_layout_cls,
     _get_layout_tensor_constructor,
     LayoutType,
+    PlainLayoutType,
     is_device,
 )
-from typing import ClassVar
 from dataclasses import dataclass
 from torchao.utils import TORCH_VERSION_AFTER_2_5

 aten = torch.ops.aten

-@dataclass(frozen=True)
-class PlainLayoutType(LayoutType):
-    pass
-
-@dataclass(frozen=True)
-class SemiSparseLayoutType(LayoutType):
-
-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        # prune to 2:4 if not already
-        temp = input.detach()
-        pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
-        temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
-        return temp
-
-
-@dataclass(frozen=True)
-class TensorCoreTiledLayoutType(LayoutType):
-    inner_k_tiles: int = 8
-
-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        orig_out_features, orig_in_features = input.shape
-        in_features = find_multiple(orig_in_features, 1024)
-        out_features = find_multiple(orig_out_features, 8)
-        input = torch.nn.functional.pad(
-            input,
-            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-        )
-        return input
-
-    def extra_repr(self):
-        return f"inner_k_tiles={self.inner_k_tiles}"
-
-
-def _aqt_is_int8(aqt):
-    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
-    return (
-        aqt.layout_tensor.dtype == torch.int8 and
-        aqt.quant_min is None or aqt.quant_min == -128 and
-        aqt.quant_max is None or aqt.quant_max == 127
-    )
-
-def _aqt_is_int8_reduced_range(aqt):
-    return (
-        aqt.layout_tensor.dtype == torch.int8 and
-        aqt.quant_min == -127 and
-        aqt.quant_max is None or aqt.quant_max == 127
-    )
-
-def _aqt_is_uint4(aqt):
-    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
-    # TODO: use torch.uint4
-    return (
-        aqt.layout_tensor.dtype == torch.int32 and
-        aqt.quant_min is None or aqt.quant_min == 0 and
-        aqt.quant_max is None or aqt.quant_max == 15
-    )
-
+###############################
+# Base Layout Tensor Subclass #
+###############################
 class AQTLayout(torch.Tensor):
     """
     Base class for the layout tensor for `AffineQuantizedTensor`

@@ -126,6 +72,10 @@ def _get_to_kwargs(self, *args, **kwargs):
         }
         return kwargs

+##############################
+# Tensor Subclass Definition #
+##############################
+
 class AffineQuantizedTensor(torch.Tensor):
     """
     Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:

@@ -337,7 +287,6 @@ def _apply_fn_to_data(self, fn):
             strides=self.stride(),
         )

-
     implements = classmethod(_implements)
     # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
     # 1. we'll add cpu/cuda version (int4mm etc.)

@@ -353,14 +302,46 @@ def _apply_fn_to_data(self, fn):
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)

-implements = AffineQuantizedTensor.implements
+
+######################################################
+# LayoutType and Layout Tensor Subclass Registration #
+######################################################

 def register_layout_cls(layout_type_class: type(LayoutType)):
     return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

 def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
     return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

+@dataclass(frozen=True)
+class SemiSparseLayoutType(LayoutType):
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        # prune to 2:4 if not already
+        temp = input.detach()
+        pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
+        temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
+        return temp
+
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+    def extra_repr(self):
+        return f"inner_k_tiles={self.inner_k_tiles}"
+
+
 @register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
     """

@@ -487,7 +468,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         )

     def get_plain(self):
-        # Currently we don't have cuSPARSELt expansion routines, so we matmul by
+        # Currently we don't have cuSPARSELt expansion routines, so we matmul by
         # the identity matrix to get the original dense matrix. This is slow though.
         cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
         int_data_expanded = torch._cslt_sparse_mm(self.int_data,

@@ -507,7 +488,7 @@ def from_plain(
         assert isinstance(layout_type, SemiSparseLayoutType)
         int_data_compressed = torch._cslt_compress(int_data)
         return cls(int_data_compressed, scale, zero_point, layout_type)
-
+

 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):

@@ -654,6 +635,34 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     def get_layout_type(self) -> LayoutType:
         return self.layout_type

+#####################################################
+# torch functional and aten operator implementation #
+#####################################################
+
+def _aqt_is_int8(aqt):
+    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
+    return (
+        aqt.layout_tensor.dtype == torch.int8 and
+        aqt.quant_min is None or aqt.quant_min == -128 and
+        aqt.quant_max is None or aqt.quant_max == 127
+    )
+
+def _aqt_is_int8_reduced_range(aqt):
+    return (
+        aqt.layout_tensor.dtype == torch.int8 and
+        aqt.quant_min == -127 and
+        aqt.quant_max is None or aqt.quant_max == 127
+    )
+
+def _aqt_is_uint4(aqt):
+    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
+    # TODO: use torch.uint4
+    return (
+        aqt.layout_tensor.dtype == torch.int32 and
+        aqt.quant_min is None or aqt.quant_min == 0 and
+        aqt.quant_max is None or aqt.quant_max == 15
+    )
+
 def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     """
     Quantized version of F.linear operator

@@ -811,8 +820,10 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     raise NotImplementedError("No specialized dispatch found for quantized linear op")


+implements = AffineQuantizedTensor.implements
+
 @implements(torch.nn.functional.linear)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
         args[1],

@@ -831,7 +842,7 @@ def _(func, types, *args, **kwargs):
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

 @implements([aten.mm.default, aten.addmm.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     if not args[0].is_floating_point():
         raise NotImplementedError(f"{func} is not implemented for non floating point input")


@@ -870,21 +881,21 @@ def _(func, types, *args, **kwargs):
     return func(input_tensor, weight_tensor)

 @implements([aten.detach.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
     )


 @implements([aten.clone.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
     )


 @implements([aten._to_copy.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func,
         args,

@@ -893,7 +904,7 @@ def _(func, types, *args, **kwargs):
     )

 @implements([aten.t.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     block_size = args[0].block_size
     assert len(block_size) == 2
     transposed_block_size = (block_size[1], block_size[0])
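With the registration helpers grouped under "LayoutType and Layout Tensor Subclass Registration", a new packing format can be plugged in without touching the core tensor subclass. A hypothetical skeleton (not part of this commit) using the APIs shown above; `MyLayoutType` and `MyAQTLayout` are made-up names:

```python
# Hypothetical skeleton, not in this commit; it only shows how the
# registration helpers above would be used for a new layout.
from dataclasses import dataclass

from torchao.dtypes.affine_quantized_tensor import (
    AQTLayout,
    register_layout_cls,
)
from torchao.dtypes.utils import LayoutType

@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    # extra packing metadata goes here, like `inner_k_tiles` above
    tile_size: int = 32

    def extra_repr(self) -> str:
        return f"tile_size={self.tile_size}"

@register_layout_cls(MyLayoutType)
class MyAQTLayout(AQTLayout):
    # a real layout would implement from_plain / get_plain / get_layout_type,
    # following PlainAQTLayout above as a template
    pass
```

Presumably `AffineQuantizedTensor` then resolves the constructor for `MyLayoutType` through `get_layout_tensor_constructor`, so quantization code does not need to special-case the new packing format.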

torchao/dtypes/utils.py (+12, -4)
@@ -32,8 +32,8 @@ def _(func, types, args, kwargs):
     def decorator(func):
         for op in aten_ops_or_torch_fns:
             @functools.wraps(op)
-            def wrapper(*args, **kwargs):
-                return func(*args, **kwargs)
+            def wrapper(f, types, args, kwargs):
+                return func(f, types, args, kwargs)

             cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
         return func

@@ -50,7 +50,7 @@ class MyTensor(torch.Tensor):
     kwargs = {} if kwargs is None else kwargs
     if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
         func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

     with torch._C.DisableTorchFunctionSubclass():
         return func(*args, **kwargs)

@@ -65,7 +65,7 @@ class MyTensor(torch.Tensor):
     """
     if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
         func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")

@@ -87,6 +87,14 @@ def __repr__(self):
     def extra_repr(self) -> str:
         return ""

+"""
+Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default
+"""
+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+
 """
 layout tensor constructor registration for different tensor subclassesa
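To make the signature change concrete, here is a small self-contained toy version (plain Python, not torchao code; all names invented) of the registration-and-dispatch pattern that `_implements` and the dispatch helpers implement, with the stored wrapper taking `(f, types, args, kwargs)` as in the new code:

```python
# Toy re-implementation of the table-based dispatch pattern above.
import functools

_TABLE = {}

def implements(aten_ops_or_torch_fns):
    if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
        aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
    def decorator(func):
        for op in aten_ops_or_torch_fns:
            @functools.wraps(func)
            def wrapper(f, types, args, kwargs):
                # args/kwargs stay packed all the way to the handler
                return func(f, types, args, kwargs)
            _TABLE[op] = wrapper
        return func
    return decorator

def dispatch(func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs
    if func in _TABLE:
        return _TABLE[func](func, types, args, kwargs)
    raise NotImplementedError(f"no override for {func}")

# usage
@implements("add")
def _(func, types, args, kwargs):
    return args[0] + args[1]

print(dispatch("add", (), (1, 2)))  # 3
```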

torchao/prototype/low_bit_optim/subclass_4bit.py (+4, -4)
@@ -89,7 +89,7 @@ def __repr__(self):


 @OptimState4bit.implements(aten.copy_.default)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     dst = args[0]
     src = args[1]


@@ -116,14 +116,14 @@ def _(func, types, *args, **kwargs):


 @OptimState4bit.implements(aten.lerp.Scalar)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
     return func(*args, **kwargs)


 # this is needed for DTensor.from_local() and for flattening tensor
 @OptimState4bit.implements(aten.view.default)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     x, shape = args

     if tuple(x.shape) == tuple(shape):

@@ -142,7 +142,7 @@ def _(func, types, *args, **kwargs):
     c10d_functional.wait_tensor.default,
     _c10d_functional.wait_tensor.default,
 ])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState4bit):
         raise ValueError(f"expecting a OptimState4bit but found {type(x)}")

torchao/prototype/low_bit_optim/subclass_8bit.py (+4, -4)
@@ -75,7 +75,7 @@ def __repr__(self):


 @OptimState8bit.implements(aten.copy_.default)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     dst = args[0]
     src = args[1]


@@ -98,14 +98,14 @@ def _(func, types, *args, **kwargs):


 @OptimState8bit.implements(aten.lerp.Scalar)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
     return func(*args, **kwargs)


 # this is needed for DTensor.from_local()
 @OptimState8bit.implements(aten.view.default)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     x, shape = args
     return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)


@@ -117,7 +117,7 @@ def _(func, types, *args, **kwargs):
     c10d_functional.wait_tensor.default,
     _c10d_functional.wait_tensor.default,
 ])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState8bit):
         raise ValueError(f"expecting a OptimState8bit but found {type(x)}")

torchao/prototype/low_bit_optim/subclass_fp8.py (+4, -4)
@@ -81,7 +81,7 @@ def __repr__(self):


 @OptimStateFp8.implements(aten.copy_.default)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     dst = args[0]
     src = args[1]


@@ -102,14 +102,14 @@ def _(func, types, *args, **kwargs):


 @OptimStateFp8.implements(aten.lerp.Scalar)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
     return func(*args, **kwargs)


 # this is needed for DTensor.from_local()
 @OptimStateFp8.implements(aten.view.default)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     x, shape = args
     return OptimStateFp8(x.codes.view(shape), x.scale)


@@ -121,7 +121,7 @@ def _(func, types, *args, **kwargs):
     c10d_functional.wait_tensor.default,
     _c10d_functional.wait_tensor.default,
 ])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimStateFp8):
         raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")
