Commit fdad96f
feat: starting layout implementation
fix: namespace of common modules
chore: remove not needed test file
fix: op name being registered
chore: can compile the cuda kernel
fix: segmentation fault
chore: wip - paste test code just to check if everything passes
feat: wip - adding layout. unpack not working
fix: circular import
feat: wip - can almost revert
feat: can unpack. just needs cleanup
chore: improve layout code
chore: wip - mm needs work
feat: wip - something seems wrong
fix: e2e test
feat: wip - add group param
fix: unpack weights
feat: marlin is implemented and correct
chore: rebase
chore: remove old import
feat: use int4 instead of dequantizing
chore: remove unused fn
feat: add checks and validation
feat: add new kernel and refactor code (#1)
  * feat: wip - adding new kernel
  * feat: wip - continue working on the unpack
  * feat: wip - working on unpacking
  * feat: remove old op
  * feat: more code changes
  * chore: remove old code
  * feat: more code
  * chore: more code changes
  * chore: more code changes
  * feat: add more documentation
  * fix: dataclass
  * feat: add more docs
  * feat: remove assert
chore: block 8 bits
chore: update comment
feat: refactor dispatch
chore: add validation on group size
chore: wip - working on fixing unpack
feat: add small readme with sources
feat: add checks
feat: tests pass & can execute llama2
1 parent a246d87 commit fdad96f
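
The net effect of the commit is a new MarlinSparseLayoutType that routes int4 weight-only quantization through the 2:4-sparse Marlin kernel. A minimal usage sketch (mirroring the end-to-end test added below; shapes are illustrative and a CUDA build of torchao is assumed):

import torch
from torch import nn
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Illustrative Llama2-style linear; real models would be loaded instead.
model = nn.Sequential(nn.Linear(4096, 11008)).half().cuda()
apply_fake_sparsity(model)  # prune weights to a 2:4 pattern
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
out = model(torch.randn(16, 4096, dtype=torch.float16, device="cuda"))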

File tree

11 files changed (+342, -10 lines)


test/sparsity/test_marlin.py

+87
@@ -0,0 +1,87 @@
import torch
import copy
import pytest

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.marlin import (
    pack_to_marlin_24,
    unpack_from_marlin_24,
    inject_24
)
from torchao.quantization.utils import (
    get_group_qparams_symmetric,
    groupwise_affine_quantize_tensor_from_qparams,
)


class SparseMarlin24(TestCase):

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_quant_sparse_marlin_layout_e2e(self):
        input = torch.randn((16, 4096), dtype=torch.float16, device="cuda")
        model = (
            nn.Sequential(
                nn.Linear(4096, 11008),  # Llama2 shapes
                nn.Linear(11008, 4096),
                nn.ReLU(),
                nn.Linear(4096, 11008),
                nn.Linear(11008, 4096),
            )
            .half()
            .cuda()
        )

        # Baseline
        ref_result = model(input)

        apply_fake_sparsity(model)
        model_copy = copy.deepcopy(model)

        # Quantized
        quantize_(model_copy.bfloat16(), int4_weight_only())
        dense_result = model_copy(input.bfloat16()).half()

        # Sparse + quantized
        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
        sparse_result = model(input)

        error_dense = torch.mean(torch.abs(ref_result - dense_result) ** 2)
        error_sparse = torch.mean(torch.abs(ref_result - sparse_result) ** 2)
        assert torch.allclose(error_dense, error_sparse, atol=1e-2), "Mean error is not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        group_size = 128
        shape = (11008, 4096)
        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        # Inject 2:4 sparsity mask
        w_24, _ = inject_24(w, *w.shape)

        # Quantize weights
        scales, zeros = get_group_qparams_symmetric(w_24, n_bit=4, groupsize=group_size)
        w_q_24 = groupwise_affine_quantize_tensor_from_qparams(
            w_24, scales, zeros, n_bit=4, groupsize=group_size
        )

        scales = scales.reshape(-1, w_q_24.shape[1])

        # Test pack/unpack equivalence
        q_w_comp, packed_scales, meta = pack_to_marlin_24(
            w_q_24, scales, num_bits, group_size
        )
        unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
            q_w_comp, packed_scales, meta, shape, group_size, num_bits
        )

        assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
        assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"


if __name__ == "__main__":
    run_tests()

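The pack/unpack test relies on inject_24 to impose a 2:4 sparsity pattern before quantization. As a rough illustration of what that pattern means, a hypothetical reference (not torchao's inject_24) that keeps the two largest-magnitude values out of every four:

import torch

def naive_inject_24(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference only: keep the 2 largest-magnitude values in
    # every group of 4 along the flattened last dim and zero the rest,
    # which is exactly the 2:4 structured-sparsity constraint.
    w_groups = w.reshape(-1, 4)
    keep = torch.topk(w_groups.abs(), k=2, dim=1).indices
    mask = torch.zeros_like(w_groups)
    mask.scatter_(1, keep, 1.0)
    return (w_groups * mask).reshape(w.shape)

# Exactly half of the entries survive.
w = torch.rand(8, 8)
w_24 = naive_inject_24(w)
assert (w_24 != 0).sum() == w.numel() // 2
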
torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu

+1-1
@@ -1123,4 +1123,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::marlin_24_gemm", &marlin_24_gemm);
 }

-} // namespace torchao
+} // namespace torchao

torchao/csrc/sparse_marlin.cpp

+1-1
@@ -5,4 +5,4 @@
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   m.impl_abstract_pystub("torchao.ops");
   m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor");
-}
+}

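The schema registered above is the op that the new Python dispatch in torchao/dtypes/affine_quantized_tensor.py (below) calls into. An untested sketch of driving it directly, stitched together from the pack/unpack test and the dispatch implementation in this commit (shapes are illustrative; a CUDA build of torchao is assumed):

import torch
import torchao.ops
from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, marlin_24_workspace
from torchao.quantization.utils import (
    get_group_qparams_symmetric,
    groupwise_affine_quantize_tensor_from_qparams,
)

num_bits, group_size = 4, 128
k, n, m = 11008, 4096, 16  # in_features, out_features, rows of activations
w = torch.rand((k, n), dtype=torch.float16, device="cuda")
x = torch.randn((m, k), dtype=torch.float16, device="cuda")

# Same preprocessing as the test: 2:4 pattern, groupwise int4 quantization.
w_24, _ = inject_24(w, *w.shape)
scales, zeros = get_group_qparams_symmetric(w_24, n_bit=num_bits, groupsize=group_size)
w_q = groupwise_affine_quantize_tensor_from_qparams(
    w_24, scales, zeros, n_bit=num_bits, groupsize=group_size
)
scales = scales.reshape(-1, w_q.shape[1])

# Pack to the marlin 2:4 format, then call the op with the schema's argument order.
w_packed, s_packed, meta = pack_to_marlin_24(w_q, scales, num_bits, group_size)
out = torchao.ops.marlin_24_gemm(
    x, w_packed, meta, s_packed, marlin_24_workspace(n), num_bits, m, n, k
)
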
torchao/dtypes/__init__.py

+2
@@ -15,6 +15,7 @@
     TensorCoreTiledLayoutType,
     Float8LayoutType,
     Float8AQTLayout,
+    MarlinSparseLayoutType,
 )

 __all__ = [
@@ -33,4 +34,5 @@
     "TensorCoreTiledLayoutType",
     "Float8LayoutType",
     "Float8AQTLayout",
+    "MarlinSparseLayoutType",
 ]

torchao/dtypes/affine_quantized_tensor.py

+215-3
@@ -1,5 +1,6 @@
 import torch
-from typing import Dict, Callable, Any, Tuple, Optional, Union
+from typing import Tuple, Optional, Union
+import torchao.ops
 from collections import defaultdict
 import functools
 import math
@@ -39,7 +40,7 @@
 logger = logging.getLogger(__name__)

 from torchao.float8.inference import Float8MMConfig
-aten = torch.ops.aten
+

 ###############################
 # Base Layout Tensor Subclass #
@@ -489,6 +490,16 @@ class Float8LayoutType(LayoutType):
     mm_config: Optional[Float8MMConfig] = None


+@dataclass(frozen=True)
+class MarlinSparseLayoutType(LayoutType):
+
+    # Inject 2:4 sparsity
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        from torchao.sparsity.marlin import inject_24  # avoid circular import
+        w_24, _ = inject_24(input, *input.shape)
+        return w_24
+
+
 @register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
     """
@@ -642,6 +653,153 @@ def from_plain(
         return cls(int_data_compressed, scale, zero_point, layout_type)


+@register_layout_cls(MarlinSparseLayoutType)
+class MarlinSparseAQTLayout(AQTLayout):
+    """
+    Layout storage class for sparse_marlin_24 layout for affine quantized tensor.
+
+    Can be used with 4 bits and 8 bits quantization.
+
+    Original marlin documentation and information:
+    https://github.com/IST-DASLab/marlin/tree/master
+
+    Sparse marlin documentation and information:
+    https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file
+
+    fields:
+      original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape
+      group_size (int): the group size used to pack the tensor
+      num_bits (int): the number of bits used to quantize the tensor
+    """
+
+    implements = classmethod(_implements)
+    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+    __torch_function__ = classmethod(_dispatch__torch_function__)
+
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        meta: torch.Tensor,
+        layout_type: LayoutType,
+        original_shape: torch.Size,
+        group_size: int,
+        num_bits: int,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        meta: torch.Tensor,
+        layout_type: LayoutType,
+        original_shape: torch.Size,
+        group_size: int,
+        num_bits: int,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+        self.meta = meta
+        self.layout_type = layout_type
+        self.original_shape = original_shape
+        self.group_size = group_size
+        self.num_bits = num_bits
+
+    def get_plain(self):
+        from torchao.sparsity.marlin import unpack_from_marlin_24  # avoid circular import
+        int_data_expanded, scales_expanded = unpack_from_marlin_24(
+            self.int_data,
+            self.scale,
+            self.meta,
+            self.original_shape,
+            self.group_size,
+            self.num_bits,
+        )
+        return int_data_expanded, scales_expanded, self.zero_point
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType,
+    ):
+        from torchao.sparsity.marlin import pack_to_marlin_24, const  # avoid circular import
+        assert isinstance(layout_type, MarlinSparseLayoutType)
+
+        # Linear layers are (in_features, out_features) but the int_data that is reaching this point
+        # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
+        # NOTE(reviewers): Please check if this is what I should do.
+        q_w_24 = int_data.t()
+        scale = scale.reshape(-1, q_w_24.shape[1])
+
+        if q_w_24.dtype != torch.int32:
+            raise ValueError("Only `torch.int32` weights are supported.")
+
+        in_features, out_features = q_w_24.shape
+        if in_features % 128 != 0 or out_features % 256 != 0:
+            raise ValueError(
+                "`in_features` must be divisible by 128 and `out_features` by 256."
+            )
+
+        # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
+        # will require a bit more work to get our current quantization flow to work with it.
+        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
+        num_bits = 4 if torch.max(q_w_24) < 16 else -1
+        if num_bits not in [4]:
+            raise ValueError(
+                f"Only {[4]} bits are supported, got {num_bits}."
+            )
+
+        group_size = in_features // scale.shape[0]
+        if group_size == 0:
+            group_size = in_features
+        assert group_size <= in_features, "Group size must be less than or equal to in_features."
+
+        if group_size not in const.SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
+            )
+
+        # Compress quantized weight to marlin 2:4 format
+        marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
+
+        return cls(
+            marlin_24_q_w_comp, marlin_24_s, zero_point,
+            meta, layout_type, q_w_24.shape,
+            group_size, num_bits
+        )
+
+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
+    def _apply_fn_to_data(self, fn):
+        self.int_data = fn(self.int_data)
+        self.scale = fn(self.scale)
+        self.zero_point = fn(self.zero_point)
+        self.meta = fn(self.meta)
+        return self
+
+
+# Marlin Sparse op dispatch registration
+@MarlinSparseAQTLayout.implements(aten.detach.default)
+def block_sparse_detach(func, types, args, kwargs):
+    return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
+
+
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
     """
@@ -758,7 +916,7 @@ def __repr__(self):
                 f"scale={scale},\n"
                 f"transposed={self.transposed}, "
                 f"layout_type={layout_type})")
-
+

 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
@@ -941,6 +1099,7 @@ def _aqt_is_uint4(aqt):
     aqt.quant_max is None or aqt.quant_max == 15
 )

+
 implements = AffineQuantizedTensor.implements

 # following are a list of (dispatch_condition, implementation) functions that takes the following args:
@@ -1219,6 +1378,58 @@ def _linear_fp_act_fp8_weight_impl(
     ).reshape(out_shape)


+def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
+    return (
+        _aqt_is_uint4(weight_tensor) and
+        input_tensor.dtype == torch.float16 and
+        len(weight_tensor.shape) == 2 and
+        weight_tensor.zero_point_domain == ZeroPointDomain.INT and
+        isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
+    )
+
+def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
+    from torchao.sparsity.marlin import marlin_24_workspace, const
+
+    sparse_w_int4 = weight_tensor.layout_tensor.int_data
+    scale = weight_tensor.layout_tensor.scale
+    meta = weight_tensor.layout_tensor.meta
+    original_shape = weight_tensor.layout_tensor.original_shape
+    num_bits = weight_tensor.layout_tensor.num_bits
+
+    # Saves batch size for reshaping back to original shape after the matmul
+    # Reshapes tensor to (m, k) where m is batch * seq_len and k is in_features
+    # NOTE(reviewers): Please check if I am handling the batch size correctly
+    batch_size = -1
+    if input_tensor.dim() == 3:
+        batch_size = input_tensor.size(0)
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+
+    size_m = input_tensor.shape[0]
+    size_n = original_shape[1]
+    size_k = input_tensor.shape[1]
+    workspace_24 = marlin_24_workspace(original_shape[1])
+
+    # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
+    if size_k % const.TILE != 0:
+        pad_size = find_multiple(size_k, const.TILE)
+        input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_size - size_k))
+        size_k = pad_size
+
+    out = torchao.ops.marlin_24_gemm(
+        input_tensor, sparse_w_int4, meta, scale,
+        workspace_24, num_bits, size_m, size_n, size_k
+    )
+    torch.cuda.synchronize()
+
+    # Reshape back to original shape
+    if batch_size != -1:
+        out = out.reshape(batch_size, -1, out.shape[-1])
+
+    if bias is not None:
+        out += bias.to(out.dtype)
+    return out
+
+
 def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
@@ -1227,6 +1438,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
+        (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl),
     ]:
         register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

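To make the validation and padding arithmetic in from_plain and the sparse-marlin linear dispatch concrete, a small worked example with hypothetical Llama2-style shapes (the tile size of 16 comes from the comment in the impl above):

# Worked example of the shape/group checks in MarlinSparseAQTLayout.from_plain
# and the padding in _linear_fp_act_int4_weight_sparse_marlin_impl.
in_features, out_features = 4096, 11008       # q_w_24.shape after the transpose
assert in_features % 128 == 0 and out_features % 256 == 0  # shape check passes

num_groups = 32                               # scale.reshape(-1, out_features).shape[0]
group_size = in_features // num_groups        # -> 128, must be in const.SUPPORTED_GROUP_SIZES

TILE = 16                                     # assumed value of const.TILE
size_k = 4100                                 # activation K not divisible by 16
pad_size = ((size_k + TILE - 1) // TILE) * TILE  # find_multiple(size_k, TILE) -> 4112
print(group_size, pad_size)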