import torch
- from typing import Dict, Callable, Any, Tuple, Optional, Union
+ from typing import Tuple, Optional, Union
+ import torchao.ops
from collections import defaultdict
import functools
import math
logger = logging.getLogger(__name__)

from torchao.float8.inference import Float8MMConfig
- aten = torch.ops.aten
+

###############################
# Base Layout Tensor Subclass #
@@ -489,6 +490,16 @@ class Float8LayoutType(LayoutType):
    mm_config: Optional[Float8MMConfig] = None


+ @dataclass(frozen=True)
+ class MarlinSparseLayoutType(LayoutType):
+
+     # Inject 2:4 sparsity
+     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+         from torchao.sparsity.marlin import inject_24  # avoid circular import
+         w_24, _ = inject_24(input, *input.shape)
+         return w_24
+
+
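For orientation, a minimal usage sketch of how this layout type would be selected from user code. The `quantize_` / `int4_weight_only` entry points, the `layout_type` keyword, and the `torchao.dtypes` re-export are assumptions based on the rest of torchao and are not part of this diff:

    import torch
    from torchao.quantization import quantize_, int4_weight_only
    from torchao.dtypes import MarlinSparseLayoutType

    # Hypothetical fp16 CUDA model; the sparse marlin linear dispatch below requires float16 activations,
    # and 4096 x 4096 satisfies the divisibility constraints checked in from_plain.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

    # pre_process injects 2:4 sparsity, then the weight is packed into the sparse_marlin_24 layout.
    quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))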
@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
    """
@@ -642,6 +653,153 @@ def from_plain(
        return cls(int_data_compressed, scale, zero_point, layout_type)


+ @register_layout_cls(MarlinSparseLayoutType)
+ class MarlinSparseAQTLayout(AQTLayout):
+     """
+     Layout storage class for the sparse_marlin_24 layout of an affine quantized tensor.
+
+     Can be used with 4-bit and 8-bit quantization.
+
+     Original Marlin documentation and information:
+     https://github.com/IST-DASLab/marlin/tree/master
+
+     Sparse Marlin documentation and information:
+     https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file
+
+     fields:
+         original_shape (torch.Size): the original shape of the tensor, used to unpack it back to that shape
+         group_size (int): the group size used to pack the tensor
+         num_bits (int): the number of bits used to quantize the tensor
+     """
+
+     implements = classmethod(_implements)
+     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+     __torch_function__ = classmethod(_dispatch__torch_function__)
+
+     def __new__(
+         cls,
+         int_data: torch.Tensor,
+         scale: torch.Tensor,
+         zero_point: torch.Tensor,
+         meta: torch.Tensor,
+         layout_type: LayoutType,
+         original_shape: torch.Size,
+         group_size: int,
+         num_bits: int,
+     ):
+         kwargs = {}
+         kwargs["device"] = int_data.device
+         kwargs["layout"] = (
+             kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+         )
+         kwargs["dtype"] = int_data.dtype
+         kwargs["requires_grad"] = False
+         shape = int_data.shape
+         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+     def __init__(
+         self,
+         int_data: torch.Tensor,
+         scale: torch.Tensor,
+         zero_point: torch.Tensor,
+         meta: torch.Tensor,
+         layout_type: LayoutType,
+         original_shape: torch.Size,
+         group_size: int,
+         num_bits: int,
+     ):
+         self.int_data = int_data
+         self.scale = scale
+         self.zero_point = zero_point
+         self.meta = meta
+         self.layout_type = layout_type
+         self.original_shape = original_shape
+         self.group_size = group_size
+         self.num_bits = num_bits
+
+     def get_plain(self):
+         from torchao.sparsity.marlin import unpack_from_marlin_24  # avoid circular import
+         int_data_expanded, scales_expanded = unpack_from_marlin_24(
+             self.int_data,
+             self.scale,
+             self.meta,
+             self.original_shape,
+             self.group_size,
+             self.num_bits,
+         )
+         return int_data_expanded, scales_expanded, self.zero_point
+
+     @classmethod
+     def from_plain(
+         cls,
+         int_data: torch.Tensor,
+         scale: torch.Tensor,
+         zero_point: torch.Tensor,
+         layout_type: LayoutType,
+     ):
+         from torchao.sparsity.marlin import pack_to_marlin_24, const  # avoid circular import
+         assert isinstance(layout_type, MarlinSparseLayoutType)
+
+         # Linear layers are (in_features, out_features) but the int_data that reaches this point
+         # is (out_features, in_features). We need to transpose it to match the shape expected by the marlin code.
+         # NOTE(reviewers): Please check if this is what I should do.
+         q_w_24 = int_data.t()
+         scale = scale.reshape(-1, q_w_24.shape[1])
+
+         if q_w_24.dtype != torch.int32:
+             raise ValueError("Only `torch.int32` weights are supported.")
+
+         in_features, out_features = q_w_24.shape
+         if in_features % 128 != 0 or out_features % 256 != 0:
+             raise ValueError(
+                 "`in_features` must be divisible by 128 and `out_features` by 256."
+             )
+
+         # NOTE: The current marlin 2:4 kernel supports both 4-bit and 8-bit quantization, but fp8
+         # will require a bit more work to get our current quantization flow to work with it.
+         # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
+         num_bits = 4 if torch.max(q_w_24) < 16 else -1
+         if num_bits not in [4]:
+             raise ValueError(
+                 f"Only {[4]} bits are supported, got {num_bits}."
+             )
+
+         group_size = in_features // scale.shape[0]
+         if group_size == 0:
+             group_size = in_features
+         assert group_size <= in_features, "Group size must be less than or equal to in_features."
+
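+         # A quick worked example of the inference above (illustrative numbers only): with
+         # in_features = 4096 and a reshaped scale of shape (32, out_features), group_size
+         # is 4096 // 32 = 128; a single scale row (per-channel quantization) yields
+         # group_size = in_features.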
+         if group_size not in const.SUPPORTED_GROUP_SIZES:
+             raise ValueError(
+                 f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
+             )
+
+         # Compress quantized weight to marlin 2:4 format
+         marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
+
+         return cls(
+             marlin_24_q_w_comp, marlin_24_s, zero_point,
+             meta, layout_type, q_w_24.shape,
+             group_size, num_bits
+         )
+
+     def get_layout_type(self) -> LayoutType:
+         return self.layout_type
+
+     def _apply_fn_to_data(self, fn):
+         self.int_data = fn(self.int_data)
+         self.scale = fn(self.scale)
+         self.zero_point = fn(self.zero_point)
+         self.meta = fn(self.meta)
+         return self
+
+
+ # Marlin Sparse op dispatch registration
+ @MarlinSparseAQTLayout.implements(aten.detach.default)
+ def marlin_sparse_detach(func, types, args, kwargs):
+     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
+
+
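Other AQT layout classes in this file register a few more aten ops (for example clone) through the same `implements` hook; a sketch of the analogous registration for this layout, shown only as an illustration of the pattern and not part of this diff:

    @MarlinSparseAQTLayout.implements(aten.clone.default)
    def marlin_sparse_clone(func, types, args, kwargs):
        return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.clone))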
@register_layout_cls(Float8LayoutType)
class Float8AQTLayout(AQTLayout):
    """
@@ -758,7 +916,7 @@ def __repr__(self):
                f"scale={scale},\n"
                f"transposed={self.transposed}, "
                f"layout_type={layout_type})")
-
+

@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
@@ -941,6 +1099,7 @@ def _aqt_is_uint4(aqt):
        aqt.quant_max is None or aqt.quant_max == 15
    )

+

implements = AffineQuantizedTensor.implements

# following are a list of (dispatch_condition, implementation) functions that takes the following args:
@@ -1219,6 +1378,58 @@ def _linear_fp_act_fp8_weight_impl(
    ).reshape(out_shape)


+ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
+     return (
+         _aqt_is_uint4(weight_tensor) and
+         input_tensor.dtype == torch.float16 and
+         len(weight_tensor.shape) == 2 and
+         weight_tensor.zero_point_domain == ZeroPointDomain.INT and
+         isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
+     )
+
+ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
+     from torchao.sparsity.marlin import marlin_24_workspace, const
+
+     sparse_w_int4 = weight_tensor.layout_tensor.int_data
+     scale = weight_tensor.layout_tensor.scale
+     meta = weight_tensor.layout_tensor.meta
+     original_shape = weight_tensor.layout_tensor.original_shape
+     num_bits = weight_tensor.layout_tensor.num_bits
+
+     # Saves the batch size for reshaping back to the original shape after the matmul.
+     # Folds the batch into the row dimension so the input becomes (m, k), where
+     # m = batch * seq_len and k = in_features.
+     # NOTE(reviewers): Please check if I am handling the batch size correctly.
+     batch_size = -1
+     if input_tensor.dim() == 3:
+         batch_size = input_tensor.size(0)
+         input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+
+     size_m = input_tensor.shape[0]
+     size_n = original_shape[1]
+     size_k = input_tensor.shape[1]
+     workspace_24 = marlin_24_workspace(original_shape[1])
+
+     # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
+     if size_k % const.TILE != 0:
+         pad_size = find_multiple(size_k, const.TILE)
+         input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_size - size_k))
+         size_k = pad_size
+
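+     # For illustration (hypothetical numbers): with const.TILE == 16, an activation with
+     # size_k = 4100 would be right-padded to find_multiple(4100, 16) == 4112 so the
+     # marlin_24_gemm tiling constraint is met; a size_k already divisible by 16 is untouched.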
+     out = torchao.ops.marlin_24_gemm(
+         input_tensor, sparse_w_int4, meta, scale,
+         workspace_24, num_bits, size_m, size_n, size_k
+     )
+     torch.cuda.synchronize()
+
+     # Reshape back to original shape
+     if batch_size != -1:
+         out = out.reshape(batch_size, -1, out.shape[-1])
+
+     if bias is not None:
+         out += bias.to(out.dtype)
+     return out
+
+
def _register_aqt_quantized_linear_dispatches():
    for dispatch_condition, impl in [
        (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
@@ -1227,6 +1438,7 @@ def _register_aqt_quantized_linear_dispatches():
        (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
        (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
        (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
+         (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl),
    ]:
        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)