
Commit 65d86c6

Diogo-V and jcaip authored
Add sparse marlin AQT layout (pytorch#621)
* feat: starting layout implementation
fix: namespace of common modules
chore: remove not needed test file
fix: op name being registered
chore: can compile the cuda kernel
fix: segmentation fault
chore: wip - paste test code just to check if everything passes
feat: wip - adding layout. unpack not working
fix: circular import
feat: wip - can almost revert
feat: can unpack. just needs cleanup
chore: improve layout code
chore: wip - mm needs work
feat: wip - something seems wrong
fix: e2e test
feat: wip - add group param
fix: unpack weights
feat: marlin is implemented and correct
chore: rebase
chore: remove old import
feat: use int4 instead of dequantizing
chore: remove unused fn
feat: add checks and validation
feat: add new kernel and refactor code (#1)

* feat: wip - adding new kernel
* feat: wip - continue working on the unpack
* feat: wip - working on unpacking
* feat: remove old op
* feat: more code changes
* chore: remove old code
* feat: more code
* chore: more code changes
* chore: more code changes
* feat: add more documentation
* fix: dataclass
* feat: add more docs
* feat: remove assert

chore: block 8 bits
chore: update comment
feat: refactor dispatch
chore: add validation on group size
chore: wip - working on fixing unpack
feat: add small readme with sources
feat: add checks
feat: tests pass & can execute llama2

* compile kind of working
* fix: batching and layout outputs correct results
* fix: torch.compile
* wip
* feat: wip
* chore: cleanup
* chore: review
* chore: review v2
* update benchmarks + README

---------

Co-authored-by: Jesse Cai <[email protected]>
1 parent 422301b commit 65d86c6

20 files changed: +538 -102 lines changed
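
For context, a minimal usage sketch of the layout this commit adds, assembled from the tests and benchmark scripts in the diff below (the toy model and tensor shapes here are hypothetical; MarlinSparseLayoutType, int4_weight_only, quantize_, and apply_fake_sparsity are the names introduced or used by this change):

import torch
from torch import nn
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Hypothetical fp16 CUDA model; the Marlin 2:4 kernel targets float16 on GPU.
model = nn.Sequential(nn.Linear(256, 1024), nn.Linear(1024, 256)).half().cuda()

# Prune weights to a 2:4 pattern, then quantize them to int4 with the sparse
# Marlin layout so linear layers dispatch to the marlin_24_gemm CUDA kernel.
apply_fake_sparsity(model)
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

out = model(torch.rand((8, 256), dtype=torch.float16, device="cuda"))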

test/sparsity/test_marlin.py

+115
@@ -0,0 +1,115 @@
import torch
import copy
import pytest

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.marlin import (
    pack_to_marlin_24,
    unpack_from_marlin_24,
    inject_24
)
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    quantize_affine,
    ZeroPointDomain,
    MappingType,
)


class SparseMarlin24(TestCase):

    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
        self.model = (
            nn.Sequential(
                nn.Linear(4096, 21504),
                nn.Linear(21504, 4096),
                nn.ReLU(),
                nn.Linear(4096, 21504),
                nn.Linear(21504, 4096),
            )
            .half()
            .cuda()
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_quant_sparse_marlin_layout_eager(self):
        apply_fake_sparsity(self.model)
        model_copy = copy.deepcopy(self.model)

        # Quantized
        quantize_(model_copy.bfloat16(), int4_weight_only())
        dense_result = model_copy(self.input.bfloat16()).half()

        # Sparse + quantized
        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
        sparse_result = self.model(self.input)

        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_quant_sparse_marlin_layout_compile(self):
        apply_fake_sparsity(self.model)
        model_copy = copy.deepcopy(self.model)

        # Quantized
        quantize_(model_copy.bfloat16(), int4_weight_only())
        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
        dense_result = model_copy(self.input.bfloat16()).half()

        # Sparse + quantized
        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
        self.model.forward = torch.compile(self.model.forward, fullgraph=True)
        sparse_result = self.model(self.input)

        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        group_size = 128
        shape = (11008, 4096)
        block_size = (1, group_size)
        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        eps = 1e-6
        zero_point_dtype = torch.bfloat16
        mapping_type = MappingType.SYMMETRIC
        preserve_zero = True
        zero_point_domain = ZeroPointDomain.INT
        scale_dtype = None

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        # Inject 2:4 sparsity mask
        w_24, _ = inject_24(w, *w.shape)

        # Quantize weights
        scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
        w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
        scales = scales.reshape(-1, w_q_24.shape[1])

        # Test pack/unpack equivalence
        q_w_comp, packed_scales, meta = pack_to_marlin_24(
            w_q_24, scales, num_bits, group_size
        )
        unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
            q_w_comp, packed_scales, meta, shape, group_size, num_bits
        )

        assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
        assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"


if __name__ == "__main__":
    run_tests()

test/sparsity/test_sparse_api.py

+27-3
@@ -11,12 +11,11 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     semi_sparse_weight,
 )
+from torchao.dtypes import MarlinSparseLayoutType
 from torchao.quantization.quant_api import (
-    _replace_with_custom_fn_if_matches_filter,
-    _get_subclass_inserter,
-    _is_linear,
     int8_dynamic_activation_int8_weight,
     quantize_,
+    int4_weight_only,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
 from torch.testing._internal.common_utils import TestCase
@@ -73,5 +72,30 @@ def test_quant_semi_sparse(self):

         assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_sparse_marlin(self):
+        input = torch.rand((256, 256)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(256, 1024),
+                nn.Linear(1024, 256),
+            )
+            .half()
+            .cuda()
+        )
+
+        apply_fake_sparsity(model)
+        model_copy = copy.deepcopy(model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), int4_weight_only())
+        dense_result = model_copy(input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+
 if __name__ == "__main__":
     unittest.main()

test/test_ops.py

+15-9
@@ -304,6 +304,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
 )


+MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
 MARLIN_24_K_CHUNKS = [128]
 MARLIN_24_N_CHUNKS = [512]
 MNK_FACTORS = [
@@ -318,8 +319,8 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
 MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

 MARLIN_TEST_PARAMS = list(itertools.product(
-    MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
-    MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
+    MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
+    MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
 ))

 def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
@@ -374,15 +375,15 @@ def reshape_w(w):
 )

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
-def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
+def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors

     size_m = m_factor
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor

-    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
+    a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda")
     b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

     # Inject 2:4 sparsity
@@ -391,19 +392,24 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
     # Symmetric quantize
     w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

+    # Reshape input into 2D tensor
+    input_2d = a_input.view(-1, a_input.shape[-1])
+    a_input_in, a_input_out = input_2d.shape
+
     # Obtains reference output
-    output_ref = torch.matmul(a_input, w_24_ref)
+    output_ref = torch.matmul(input_2d, w_24_ref)
+    output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],))

     # Packs to marlin 2:4
     marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
     workspace_24 = marlin_24_workspace(size_n)

     fn_inputs = (
-        a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
-        num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
+        input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
+        num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
     )
     output = torchao.ops.marlin_24_gemm(*fn_inputs)
-    torch.cuda.synchronize()
+    output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],))

     max_diff = compute_max_diff(output, output_ref)
     assert max_diff < 0.04
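
The updated test above flattens the batched activation to 2D before calling the kernel and restores the leading dimensions afterwards. As a sketch of that pattern (hypothetical helper name; the arguments mirror the test's fn_inputs):

import torch
import torchao

def marlin_24_linear(a, q_w_comp, meta, scale, workspace, num_bits):
    # marlin_24_gemm expects a 2D activation, so collapse the leading dims.
    a_2d = a.view(-1, a.shape[-1])
    out = torchao.ops.marlin_24_gemm(
        a_2d, q_w_comp, meta, scale, workspace,
        num_bits, a_2d.shape[0], scale.shape[1], a_2d.shape[1],
    )
    # Reshape back so the output keeps the original batch dimensions.
    return out.reshape(a.shape[:-1] + (scale.shape[1],))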

torchao/_models/llama/benchmark_results.txt

+1
@@ -38,3 +38,4 @@ kv cache quantization:
 20240826171015, tok/s= 1.95, mem/s= 29.21 GB/s, peak_mem=59.27 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072
 20240826172121, tok/s= 1.73, mem/s= 26.02 GB/s, peak_mem=52.62 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization
 20240826173230, tok/s= 1.73, mem/s= 25.95 GB/s, peak_mem=34.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization --linear_causal_mask
+20240906054415, tok/s=226.02, mem/s= 689.20 GB/s, peak_mem= 5.32 GB, model_size= 3.05 GB quant: marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

torchao/_models/llama/benchmarks.sh

+3-1
@@ -30,13 +30,15 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
-
+# sparse marlin (NOTE: float16)
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
 # auto-round w/ quant_lm_head
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
 # auto-round w/o quant_lm_head
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0


+
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization

torchao/_models/llama/generate.py

+3
@@ -225,6 +225,9 @@ def main(
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
+        if "marlin" in quantization:
+            from torchao.dtypes import MarlinSparseLayoutType
+            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         if "autoround" in quantization:
             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
             from transformers import AutoTokenizer

torchao/_models/sam/benchmark.sh

+2
@@ -8,3 +8,5 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+# int8 dynamic quant attn + int4 wo + sparse marlin lin 1 + 2:4 sparse lin2
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half float16 --device cuda --compress int4_weight_only_sparse

torchao/_models/sam/eval_combo.py

+24-10
@@ -283,6 +283,16 @@ def run(
     for block in predictor.model.image_encoder.blocks:
         block.attn.use_rel_pos = use_rel_pos

+    # Helper filter functions
+    def attn_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'attn' in name
+    def mlp_lin1_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+    def mlp_lin2_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+    def mlp_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+
     if compress == "int8_dynamic_quant":
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
         if not TORCH_VERSION_AT_LEAST_2_5:
@@ -296,15 +306,6 @@ def mlp_only(mod, name):
         apply_fake_sparsity(predictor.model.image_encoder)
         sparsify_(predictor.model.image_encoder, semi_sparse_weight())
     elif compress == "int8_dynamic_quant_sparse":
-        def attn_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'attn' in name
-        def mlp_lin1_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'lin1' in name
-        def mlp_lin2_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'lin2' in name
-        def mlp_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-
         # apply sparsify first to set qparams
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
@@ -320,7 +321,20 @@ def mlp_only(mod, name):
                   mlp_lin2_only)
         if not TORCH_VERSION_AT_LEAST_2_5:
             predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
-
+    elif compress == "int4_weight_only_sparse":
+        # apply sparsify first to set qparams
+        apply_fake_sparsity(predictor.model.image_encoder,
+                            filter_fn=mlp_only)
+        from torchao.dtypes import MarlinSparseLayoutType
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_weight(),
+                  attn_only)
+        quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
+        sparsify_(predictor.model.image_encoder,
+                  semi_sparse_weight(),
+                  mlp_lin2_only)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
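
The new int4_weight_only_sparse branch above composes three schemes through filter functions: int8 dynamic quantization for attention projections, int4 weight-only with the sparse Marlin layout for the first MLP linear, and plain 2:4 sparsity for the second. A minimal standalone sketch of the same composition on a toy module (the module, layer names, and filter lambdas are hypothetical, and the exact import path for sparsify_/semi_sparse_weight may differ from the one assumed here):

import torch
from torch import nn
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    quantize_,
)
from torchao.sparsity.sparse_api import apply_fake_sparsity, sparsify_, semi_sparse_weight

# Hypothetical block with an attention projection and a two-layer MLP.
class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(256, 256)
        self.lin1 = nn.Linear(256, 1024)
        self.lin2 = nn.Linear(1024, 256)
    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(self.attn(x))))

model = ToyBlock().half().cuda()

attn_only = lambda mod, name: isinstance(mod, nn.Linear) and "attn" in name
lin1_only = lambda mod, name: isinstance(mod, nn.Linear) and "lin1" in name
lin2_only = lambda mod, name: isinstance(mod, nn.Linear) and "lin2" in name
mlp_only = lambda mod, name: isinstance(mod, nn.Linear) and ("lin1" in name or "lin2" in name)

# Sparsify the MLP first so quantization parameters are chosen on pruned
# weights, then apply a different scheme to each group of layers.
apply_fake_sparsity(model, filter_fn=mlp_only)
quantize_(model, int8_dynamic_activation_int8_weight(), attn_only)
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()), lin1_only)
sparsify_(model, semi_sparse_weight(), lin2_only)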

torchao/_models/sam/results.csv

+1
@@ -4,3 +4,4 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
 cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,17068,21,23.96093702681232,41.73459489004953,0.5485481164943489,max-autotune,torch.float16,int4_weight_only_sparse,False,True,True,32,154,4928,None,None

torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu

+1-1
@@ -1123,4 +1123,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::marlin_24_gemm", &marlin_24_gemm);
 }

-} // namespace torchao
+} // namespace torchao

torchao/csrc/sparse_marlin.cpp

+1-1
@@ -5,4 +5,4 @@
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   m.impl_abstract_pystub("torchao.ops");
   m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor");
-}
+}

torchao/dtypes/__init__.py

+2
@@ -15,6 +15,7 @@
     TensorCoreTiledLayoutType,
     Float8LayoutType,
     Float8AQTLayout,
+    MarlinSparseLayoutType,
 )

 __all__ = [
@@ -33,4 +34,5 @@
     "TensorCoreTiledLayoutType",
     "Float8LayoutType",
     "Float8AQTLayout",
+    "MarlinSparseLayoutType",
 ]
