
Commit 7948b95

Diogo-Vjcaip authored and committed
feat: add new kernel and refactor code (#1)
* feat: wip - adding new kernel
* feat: wip - continue working on the unpack
* feat: wip - working on unpacking
* feat: remove old op
* feat: more code changes
* chore: remove old code
* feat: more code
* chore: more code changes
* chore: more code changes
* feat: add more documentation
* fix: dataclass
* feat: add more docs
* feat: remove assert
1 parent 08ab816 commit 7948b95


14 files changed: +1502 -912 lines changed


test/sparsity/test_marlin.py (+22 -14)
@@ -8,8 +8,9 @@
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
 from torchao.sparsity.marlin import (
-    pack_to_sparse_marlin_24,
-    unpack_from_sparse_marlin_24,
+    pack_to_marlin_24,
+    unpack_from_marlin_24,
+    inject_24
 )


@@ -24,41 +25,48 @@ def test_quant_sparse_marlin_layout_e2e(self):
                 nn.Linear(21504, 256),
                 nn.ReLU(),
                 nn.Linear(256, 128),
+                nn.ReLU(),
+                nn.Linear(128, 4096),
             )
             .half()
             .cuda()
         )

+        # Baseline
+        ref_result = model(input)
+
         apply_fake_sparsity(model)
         model_copy = copy.deepcopy(model)

-        # Baseline to match against
+        # Quantized
         quantize_(model_copy.bfloat16(), int4_weight_only())
         dense_result = model_copy(input.bfloat16()).half()

         # Sparse + quantized
         quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         sparse_result = model(input)

-        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2), "Sparse and dense results do not match"
+        error_dense = torch.mean(torch.abs(ref_result - dense_result) ** 2)
+        error_sparse = torch.mean(torch.abs(ref_result - sparse_result) ** 2)
+        assert torch.allclose(error_dense, error_sparse, atol=1e-3), "Mean error is not close"

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_pack_unpack_equivalence(self):
-        tiles = 16
+        num_bits = 4
+        group_size = 128
         shape = (512, 4096)
-        w_int4 = torch.randint(0, 15, shape).int().cuda()
+        w_q = torch.randint(0, 15, shape).int().cuda()
         scales = torch.rand(4096).cuda()

-        # Test pack/unpack equivalence
-        sparse_w_int4, packed_scales, meta = pack_to_sparse_marlin_24(w_int4, scales, tiles)
-        unpacked_w_int4, unpacked_scales = unpack_from_sparse_marlin_24(sparse_w_int4, packed_scales, meta, tiles, shape)
+        w_q_24, _ = inject_24(w_q, *w_q.shape)

-        # When unpacking, that values that were masked will be zeroed out. So, we need
-        # to zero out the same values in the original weights to compare
-        makeshift_mask = unpacked_w_int4 == 0
-        w_int4[makeshift_mask] = 0
+        # Test pack/unpack equivalence
+        q_w_comp, packed_scales, meta = pack_to_marlin_24(w_q_24, scales, num_bits, group_size)
+        unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
+            q_w_comp, packed_scales, meta, shape, group_size, num_bits
+        )

-        assert torch.equal(w_int4, unpacked_w_int4), "Unpacked weights do not match original weights"
+        assert torch.equal(w_q, unpacked_q_w), "Unpacked weights do not match original weights"
         assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"
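The test above exercises the new public API end to end: weights are pruned to a 2:4 pattern with inject_24, compressed with pack_to_marlin_24, and recovered with unpack_from_marlin_24. Below is a minimal standalone sketch of that round trip, using only the argument order visible in the diff; shapes and values mirror the test, so treat it as illustrative rather than authoritative.

import torch
from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24

num_bits, group_size = 4, 128
shape = (512, 4096)

# Random 4-bit weight values and per-output-channel scales, as in the test above
w_q = torch.randint(0, 15, shape).int().cuda()
scales = torch.rand(shape[1]).cuda()

# Prune to 2:4 sparsity first; packing expects an already 2:4-sparse tensor
w_q_24, _ = inject_24(w_q, *w_q.shape)

# Compress into the Marlin 2:4 layout, then recover the dense tensors
q_w_comp, packed_scales, meta = pack_to_marlin_24(w_q_24, scales, num_bits, group_size)
unpacked_w, unpacked_scales = unpack_from_marlin_24(
    q_w_comp, packed_scales, meta, shape, group_size, num_bits
)

# The round trip should reproduce the scales exactly (the test also checks the weights)
assert torch.equal(scales, unpacked_scales)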

test/test_ops.py (+91 -118)
@@ -10,8 +10,9 @@
     run_tests,
 )
 from torch.testing._internal.optests import opcheck
-from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
 from torchao.prototype.quant_llm import from_scaled_tc_fpx
+from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
 import pytest

 if is_fbcode():
@@ -22,12 +23,6 @@
     except RuntimeError:
         pytest.skip("torchao.ops not available")

-from torchao.sparsity.utils import mask_creator
-from torchao.sparsity.marlin import (
-    pack_to_sparse_marlin_24,
-    marlin_24_mm,
-    fp16_to_int4_marlin_format
-)
 from torchao.quantization.utils import (
     get_groupwise_affine_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
@@ -309,139 +304,117 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
     )


-class SparseMarlin24(TestCase):
-    TILES = 16
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [512]
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+]
+MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

-    def _op_check(self, inputs, sparse_w_int4, meta, scales, workspace, thread_k, thread_m, sms=-1, max_par=16):
-        out = torch.empty((inputs.size(0), scales.size(1)), dtype=inputs.dtype, device=inputs.device)
+MARLIN_TEST_PARAMS = list(itertools.product(
+    MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
+    MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
+))

-        prob_n = inputs.size(0)
-        prob_m = out.size(1)
-        prob_k = inputs.size(1)
-        group_size = -1 if scales.size(0) == 1 else int(prob_k / 2 / scales.size(0))
-        device = torch.cuda.current_device()
+def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
+    orig_device = w.device
+    size_k, size_n = w.shape

-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(
-            torch.ops.torchao.marlin_24_mm,
-            (
-                inputs, sparse_w_int4, meta, out, scales, prob_m, prob_n, prob_k,
-                workspace, group_size, device, thread_k, thread_m, sms, max_par
-            ),
-            test_utils=test_utils,
-        )
-
-    def _gen_values(self, m, n, k, group_size):
-        maxq = 2**4 - 1
-        inputs = torch.randn((n, k), dtype=torch.half, device="cuda")
-        w = torch.randn((m, k), dtype=torch.half, device="cuda")
-
-        w = w.t()
-        if group_size != -1:
-            w = w.reshape((-1, group_size, m))
-            w = w.permute(1, 0, 2)
-            w = w.reshape((group_size, -1))
+    assert w.is_floating_point(), "w must be float"

-        scales = torch.max(torch.abs(w), 0, keepdim=True)[0]
-        scales *= 2 / maxq
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k

-        w = torch.round(w / scales).int()
-        w += (maxq + 1) // 2
-        w = torch.clamp(w, 0, maxq)
+    max_q_val = 2**num_bits - 1
+    half_q_val = (max_q_val + 1) // 2

-        w_fp16 = (w - (maxq + 1) // 2).half() * scales
-        scales = scales.reshape((-1, m)).contiguous()
+    # Reshape to [groupsize, -1]
+    if group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))

-        if group_size != -1:
+    # Compute scale for each group
+    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
+    s *= 2 / max_q_val  # 2 => symmetric

-            def reshape(w):
-                w = w.reshape((group_size, -1, m))
-                w = w.permute(1, 0, 2)
-                w = w.reshape((k, m)).contiguous()
-                return w
+    # Quantize
+    q_w = torch.round(w / s).int()
+    q_w += half_q_val
+    q_w = torch.clamp(q_w, 0, max_q_val)

-            w_fp16 = reshape(w_fp16)
-            w = reshape(w)
-
-        mask = mask_creator(w.T).cuda().bool()
-        sparse_w_fp16_ref = (mask * w_fp16.T).T
+    # Compute ref (dequantized)
+    w_ref = (q_w - half_q_val).half() * s

-        return inputs, sparse_w_fp16_ref, w_fp16, scales
+    # Restore original shapes
+    if group_size < size_k:

-    def _run_problem(self, m, n, k, thread_k, thread_m, group_size=-1):
-        inputs, sparse_w_fp16_ref, w_fp16, scales = self._gen_values(m, n, k, group_size)
-        out_ref = torch.matmul(inputs, sparse_w_fp16_ref)
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w

-        # If no groupsize is provided, we assume it is the same as the in_features of the weights
-        # https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L290
-        if group_size == -1:
-            group_size = k
+        q_w = reshape_w(q_w)
+        w_ref = reshape_w(w_ref)

-        w_int4, scales = fp16_to_int4_marlin_format(w_fp16, scales, group_size)
-        sparse_w_int4, scales, meta = pack_to_sparse_marlin_24(w_int4, scales, self.TILES)
+    s = s.reshape((-1, size_n)).contiguous()

-        workspace = torch.zeros(m // 128 * 16, device="cuda", dtype=torch.int32)
-        out = marlin_24_mm(inputs, sparse_w_int4, meta, scales, workspace, thread_k, thread_m, -1)
-        torch.cuda.synchronize()
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        s.to(device=orig_device),
+    )

-        self.assertLess(
-            torch.mean(torch.abs(out - out_ref)) / torch.mean(torch.abs(out_ref)), 0.002
-        )
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
+def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors

-        # TODO(diogo): Enable this check once I understand how to make `out` mutable
-        # self._op_check(inputs, sparse_w_int4, meta, scales, workspace, thread_k, thread_m)
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_correctness(self):
-        self._run_problem(256, 16, 256, 128, 128, -1)
-        self._run_problem(21504, 16, 4096, 64, 256, 128)
+    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
+    b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_tiles(self):
-        for m in [1, 2, 4, 8, 12, 16, 32, 64]:
-            for thread_k, thread_n in [(64, 256), (128, 128)]:
-                if m > 16 and thread_k == 128:
-                    continue
-                self._run_problem(2 * 256, m, 1024, thread_k, thread_n)
+    # Inject 2:4 sparsity
+    w_24, _ = inject_24(b_weight, size_k, size_n)

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_k_stages_divisibility(self):
-        for k in [3 * 64 + 64 * 4 * 2 + 64 * i for i in range(1, 4)]:
-            self._run_problem(2 * 256, 16, k, 64, 256)
+    # Symmetric quantize
+    w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_very_few_stages(self):
-        for k in [64, 128, 192]:
-            self._run_problem(3 * 256, 16, k, 64, 256)
+    # Obtains reference output
+    output_ref = torch.matmul(a_input, w_24_ref)

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_llama_shapes(self):
-        MODELS = {
-            " 7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
-            "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
-            "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
-            "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
-        }
-
-        try:
-            for _, layers in MODELS.items():
-                for layer in layers:
-                    for thread_k, thread_m in [(128, 128)]:
-                        for batch in [16]:
-                            print(layer[1], batch, layer[0])
-                            self._run_problem(layer[1], batch, layer[0], thread_k, thread_m)
-        # If someone runs this on a GPU with less than 24GB of memory, it will run out of memory
-        # but we don't want to fail the test
-        except torch.OutOfMemoryError:
-            pass
+    # Packs to marlin 2:4
+    marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
+    workspace_24 = marlin_24_workspace(size_n)

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_groups(self):
-        for m in [16]:
-            for groupsize in [128]:
-                for n, k in [(256, 512), (256, 1024), (256 * 128, 1024)]:
-                    for thread_shape in [(128, 128), (64, 256)]:
-                        self._run_problem(n, m, k, *thread_shape, groupsize)
+    fn_inputs = (
+        a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
+        num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
+    )
+    output = torchao.ops.marlin_24_gemm(*fn_inputs)
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 0.04
+
+    # Performs opcheck
+    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
+    opcheck(
+        torch.ops.torchao.marlin_24_gemm,
+        fn_inputs,
+        test_utils=test_utils,
+    )


 if __name__ == "__main__":
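For readers skimming the new test: _symmetric_quantize_with_ref is plain symmetric, per-group quantization — one scale per group chosen so that ±max|w| spans the unsigned num_bits range, followed by round, shift, and clamp; the dequantized w_ref is what marlin_24_gemm's output is compared against via compute_max_diff. A tiny standalone sketch of the same arithmetic on one made-up group (values are illustrative only, not from the commit):

import torch

num_bits = 4
max_q_val = 2**num_bits - 1        # 15
half_q_val = (max_q_val + 1) // 2  # 8

# One quantization group (values made up for illustration)
w = torch.tensor([0.50, -0.25, 0.10, -0.40], dtype=torch.half)

# One scale per group: 2 * max|w| / max_q_val, so +/- max|w| covers the 4-bit range
s = w.abs().max() * 2 / max_q_val

# Quantize: scale, shift into the unsigned range [0, 15], clamp
q_w = torch.clamp(torch.round(w / s).int() + half_q_val, 0, max_q_val)

# Dequantized reference used for the matmul comparison
w_ref = (q_w - half_q_val).half() * s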

torchao/csrc/cuda/sparse_marlin/base.h (+7 -5)
@@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
 // corresponding index accesses must be compile-time constants, which is why we
 // extensively use `#pragma unroll` throughout the kernel code to guarantee
 // this.
-template <typename T, int n> struct Vec {
+template <typename T, int n>
+struct Vec {
   T elems[n];
-  __device__ T &operator[](int i) { return elems[i]; }
+  __device__ T& operator[](int i) { return elems[i]; }
 };

-template <int M_, int N_, int K_> struct ShapeBase {
+template <int M_, int N_, int K_>
+struct ShapeBase {
   static constexpr int M = M_, N = N_, K = K_;
 };

@@ -44,6 +46,6 @@ using FragA = Vec<half2, 4>;
 using FragB = Vec<half2, 2>;
 using FragM = Vec<uint, 1>;
 using FragC = Vec<float, 4>;
-using FragS = Vec<half2, 1>; // quantization scales
+using FragS = Vec<half2, 1>;  // quantization scales

-} // namespace torchao
+}  // namespace torchao
