    _register_layout_cls,
    _get_layout_tensor_constructor,
    LayoutType,
+    PlainLayoutType,
    is_device,
)
-from typing import ClassVar
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

aten = torch.ops.aten

-@dataclass(frozen=True)
-class PlainLayoutType(LayoutType):
-    pass
-
-@dataclass(frozen=True)
-class SemiSparseLayoutType(LayoutType):
-
-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        # prune to 2:4 if not already
-        temp = input.detach()
-        pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
-        temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
-        return temp
-
-
-@dataclass(frozen=True)
-class TensorCoreTiledLayoutType(LayoutType):
-    inner_k_tiles: int = 8
-
-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        orig_out_features, orig_in_features = input.shape
-        in_features = find_multiple(orig_in_features, 1024)
-        out_features = find_multiple(orig_out_features, 8)
-        input = torch.nn.functional.pad(
-            input,
-            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-        )
-        return input
-
-    def extra_repr(self):
-        return f"inner_k_tiles={self.inner_k_tiles}"
-
-
-def _aqt_is_int8(aqt):
-    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
-    return (
-        aqt.layout_tensor.dtype == torch.int8 and
-        aqt.quant_min is None or aqt.quant_min == -128 and
-        aqt.quant_max is None or aqt.quant_max == 127
-    )
-
-def _aqt_is_int8_reduced_range(aqt):
-    return (
-        aqt.layout_tensor.dtype == torch.int8 and
-        aqt.quant_min == -127 and
-        aqt.quant_max is None or aqt.quant_max == 127
-    )
-
-def _aqt_is_uint4(aqt):
-    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
-    # TODO: use torch.uint4
-    return (
-        aqt.layout_tensor.dtype == torch.int32 and
-        aqt.quant_min is None or aqt.quant_min == 0 and
-        aqt.quant_max is None or aqt.quant_max == 15
-    )
-
+###############################
+# Base Layout Tensor Subclass #
+###############################
class AQTLayout(torch.Tensor):
    """
    Base class for the layout tensor for `AffineQuantizedTensor`
@@ -126,6 +72,10 @@ def _get_to_kwargs(self, *args, **kwargs):
        }
        return kwargs

+##############################
+# Tensor Subclass Definition #
+##############################
+
class AffineQuantizedTensor(torch.Tensor):
    """
    Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
@@ -337,7 +287,6 @@ def _apply_fn_to_data(self, fn):
            strides=self.stride(),
        )

-
    implements = classmethod(_implements)
    # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
    # 1. we'll add cpu/cuda version (int4mm etc.)
@@ -353,14 +302,46 @@ def _apply_fn_to_data(self, fn):
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
    __torch_function__ = classmethod(_dispatch__torch_function__)

-implements = AffineQuantizedTensor.implements
+
+######################################################
+# LayoutType and Layout Tensor Subclass Registration #
+######################################################

def register_layout_cls(layout_type_class: type(LayoutType)):
    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

+@dataclass(frozen=True)
+class SemiSparseLayoutType(LayoutType):
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        # prune to 2:4 if not already
+        temp = input.detach()
+        pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
+        temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
+        return temp
+
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+    def extra_repr(self):
+        return f"inner_k_tiles={self.inner_k_tiles}"
+
+
@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
    """
@@ -487,7 +468,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
        )

    def get_plain(self):
-        # Currently we don't have cuSPARSELt expansion routines, so we matmul by 
+        # Currently we don't have cuSPARSELt expansion routines, so we matmul by
        # the identity matrix to get the original dense matrix. This is slow though.
        cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
        int_data_expanded = torch._cslt_sparse_mm(self.int_data,
@@ -507,7 +488,7 @@ def from_plain(
        assert isinstance(layout_type, SemiSparseLayoutType)
        int_data_compressed = torch._cslt_compress(int_data)
        return cls(int_data_compressed, scale, zero_point, layout_type)
-    
+

@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
@@ -654,6 +635,34 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    def get_layout_type(self) -> LayoutType:
        return self.layout_type

+#####################################################
+# torch functional and aten operator implementation #
+#####################################################
+
+def _aqt_is_int8(aqt):
+    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
+    return (
+        aqt.layout_tensor.dtype == torch.int8 and
+        aqt.quant_min is None or aqt.quant_min == -128 and
+        aqt.quant_max is None or aqt.quant_max == 127
+    )
+
+def _aqt_is_int8_reduced_range(aqt):
+    return (
+        aqt.layout_tensor.dtype == torch.int8 and
+        aqt.quant_min == -127 and
+        aqt.quant_max is None or aqt.quant_max == 127
+    )
+
+def _aqt_is_uint4(aqt):
+    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
+    # TODO: use torch.uint4
+    return (
+        aqt.layout_tensor.dtype == torch.int32 and
+        aqt.quant_min is None or aqt.quant_min == 0 and
+        aqt.quant_max is None or aqt.quant_max == 15
+    )
+
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
    """
    Quantized version of F.linear operator
@@ -811,8 +820,10 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
    raise NotImplementedError("No specialized dispatch found for quantized linear op")


+implements = AffineQuantizedTensor.implements
+
@implements(torch.nn.functional.linear)
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
@@ -831,7 +842,7 @@ def _(func, types, *args, **kwargs):
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements([aten.mm.default, aten.addmm.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
    if not args[0].is_floating_point():
        raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -870,21 +881,21 @@ def _(func, types, *args, **kwargs):
        return func(input_tensor, weight_tensor)

@implements([aten.detach.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements([aten.clone.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


@implements([aten._to_copy.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func,
        args,
@@ -893,7 +904,7 @@ def _(func, types, *args, **kwargs):
    )

@implements([aten.t.default])
-def _(func, types, *args, **kwargs):
+def _(func, types, args, kwargs):
    block_size = args[0].block_size
    assert len(block_size) == 2
    transposed_block_size = (block_size[1], block_size[0])
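
For reference while reviewing (not part of the diff itself): the relocated `SemiSparseLayoutType.pre_process` zeroes the two smallest-magnitude entries in every contiguous group of four, giving 2:4 sparsity. A minimal, self-contained sketch of that pruning step; the helper name `prune_2_4` is illustrative and not torchao API, and it works on a copy rather than mutating the input:

```python
import torch

def prune_2_4(x: torch.Tensor) -> torch.Tensor:
    # Keep only the two largest-magnitude values in each group of four,
    # mirroring the pre_process logic moved in this diff.
    temp = x.detach().clone()
    pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]  # two smallest per group
    temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
    return temp

w = torch.randn(8, 16)
pruned = prune_2_4(w)
assert (pruned.view(-1, 4) == 0).sum(dim=1).min() >= 2  # at least 2 zeros per group of 4
```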
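Similarly, `TensorCoreTiledLayoutType.pre_process` pads the weight so its shape matches the tiling the int4 tensor-core kernels expect. A rough sketch under the assumption that `find_multiple(n, k)` simply rounds `n` up to a multiple of `k` (a simplified stand-in for torchao's helper, not the real import):

```python
import torch

def find_multiple(n: int, k: int) -> int:
    # Assumed behavior of torchao's find_multiple: round n up to the next multiple of k.
    return n if n % k == 0 else n + k - (n % k)

def pad_for_tensor_core_tiles(w: torch.Tensor) -> torch.Tensor:
    # Pad in_features to a multiple of 1024 and out_features to a multiple of 8,
    # as the relocated pre_process does before packing into the tiled layout.
    orig_out_features, orig_in_features = w.shape
    in_features = find_multiple(orig_in_features, 1024)
    out_features = find_multiple(orig_out_features, 8)
    return torch.nn.functional.pad(
        w, (0, in_features - orig_in_features, 0, out_features - orig_out_features)
    )

w = torch.randn(30, 1000)
print(pad_for_tensor_core_tiles(w).shape)  # torch.Size([32, 1024])
```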