@@ -10,8 +10,9 @@
     run_tests,
 )
 from torch.testing._internal.optests import opcheck
-from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
 from torchao.prototype.quant_llm import from_scaled_tc_fpx
+from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
 import pytest
 
 if is_fbcode():
@@ -22,12 +23,6 @@
 except RuntimeError:
     pytest.skip("torchao.ops not available")
 
-from torchao.sparsity.utils import mask_creator
-from torchao.sparsity.marlin import (
-    pack_to_sparse_marlin_24,
-    marlin_24_mm,
-    fp16_to_int4_marlin_format
-)
 from torchao.quantization.utils import (
     get_groupwise_affine_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
@@ -309,139 +304,117 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
     )
 
 
-class SparseMarlin24(TestCase):
-    TILES = 16
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [512]
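+# Each tuple is (m_factor, n_factor, k_factor); the test builds size_m = m_factor,
+# size_k = k_chunk * k_factor, size_n = n_chunk * n_factor.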
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+]
+MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
 
-    def _op_check(self, inputs, sparse_w_int4, meta, scales, workspace, thread_k, thread_m, sms=-1, max_par=16):
-        out = torch.empty((inputs.size(0), scales.size(1)), dtype=inputs.dtype, device=inputs.device)
+MARLIN_TEST_PARAMS = list(itertools.product(
+    MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
+    MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
+))
 
-        prob_n = inputs.size(0)
-        prob_m = out.size(1)
-        prob_k = inputs.size(1)
-        group_size = -1 if scales.size(0) == 1 else int(prob_k / 2 / scales.size(0))
-        device = torch.cuda.current_device()
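+# Reference symmetric quantizer used by test_marlin_24 below; returns the
+# dequantized fp16 reference weight, the quantized integer weight, and the per-group scales.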
+def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
+    orig_device = w.device
+    size_k, size_n = w.shape
 
-        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(
-            torch.ops.torchao.marlin_24_mm,
-            (
-                inputs, sparse_w_int4, meta, out, scales, prob_m, prob_n, prob_k,
-                workspace, group_size, device, thread_k, thread_m, sms, max_par
-            ),
-            test_utils=test_utils,
-        )
-
-    def _gen_values(self, m, n, k, group_size):
-        maxq = 2**4 - 1
-        inputs = torch.randn((n, k), dtype=torch.half, device="cuda")
-        w = torch.randn((m, k), dtype=torch.half, device="cuda")
-
-        w = w.t()
-        if group_size != -1:
-            w = w.reshape((-1, group_size, m))
-            w = w.permute(1, 0, 2)
-            w = w.reshape((group_size, -1))
+    assert w.is_floating_point(), "w must be float"
 
-        scales = torch.max(torch.abs(w), 0, keepdim=True)[0]
-        scales *= 2 / maxq
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
 
-        w = torch.round(w / scales).int()
-        w += (maxq + 1) // 2
-        w = torch.clamp(w, 0, maxq)
+    max_q_val = 2**num_bits - 1
+    half_q_val = (max_q_val + 1) // 2
 
-        w_fp16 = (w - (maxq + 1) // 2).half() * scales
-        scales = scales.reshape((-1, m)).contiguous()
+    # Reshape to [groupsize, -1]
+    if group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
 
-        if group_size != -1:
+    # Compute scale for each group
+    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
+    s *= 2 / max_q_val # 2 => symmetric
 
-            def reshape(w):
-                w = w.reshape((group_size, -1, m))
-                w = w.permute(1, 0, 2)
-                w = w.reshape((k, m)).contiguous()
-                return w
+    # Quantize
+    q_w = torch.round(w / s).int()
+    q_w += half_q_val
+    q_w = torch.clamp(q_w, 0, max_q_val)
 
-            w_fp16 = reshape(w_fp16)
-            w = reshape(w)
-
-        mask = mask_creator(w.T).cuda().bool()
-        sparse_w_fp16_ref = (mask * w_fp16.T).T
+    # Compute ref (dequantized)
+    w_ref = (q_w - half_q_val).half() * s
 
-        return inputs, sparse_w_fp16_ref, w_fp16, scales
+    # Restore original shapes
+    if group_size < size_k:
 
-    def _run_problem(self, m, n, k, thread_k, thread_m, group_size=-1):
-        inputs, sparse_w_fp16_ref, w_fp16, scales = self._gen_values(m, n, k, group_size)
-        out_ref = torch.matmul(inputs, sparse_w_fp16_ref)
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
 
-        # If no groupsize is provided, we assume it is the same as the in_features of the weights
-        # https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L290
-        if group_size == -1:
-            group_size = k
+        q_w = reshape_w(q_w)
+        w_ref = reshape_w(w_ref)
 
-        w_int4, scales = fp16_to_int4_marlin_format(w_fp16, scales, group_size)
-        sparse_w_int4, scales, meta = pack_to_sparse_marlin_24(w_int4, scales, self.TILES)
+    s = s.reshape((-1, size_n)).contiguous()
 
-        workspace = torch.zeros(m // 128 * 16, device="cuda", dtype=torch.int32)
-        out = marlin_24_mm(inputs, sparse_w_int4, meta, scales, workspace, thread_k, thread_m, -1)
-        torch.cuda.synchronize()
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        s.to(device=orig_device),
+    )
 
-        self.assertLess(
-            torch.mean(torch.abs(out - out_ref)) / torch.mean(torch.abs(out_ref)), 0.002
-        )
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
+def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors
 
-        # TODO(diogo): Enable this check once I understand how to make `out` mutable
-        # self._op_check(inputs, sparse_w_int4, meta, scales, workspace, thread_k, thread_m)
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_correctness(self):
-        self._run_problem(256, 16, 256, 128, 128, -1)
-        self._run_problem(21504, 16, 4096, 64, 256, 128)
+    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
+    b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_tiles(self):
-        for m in [1, 2, 4, 8, 12, 16, 32, 64]:
-            for thread_k, thread_n in [(64, 256), (128, 128)]:
-                if m > 16 and thread_k == 128:
-                    continue
-                self._run_problem(2 * 256, m, 1024, thread_k, thread_n)
+    # Inject 2:4 sparsity
+    w_24, _ = inject_24(b_weight, size_k, size_n)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_k_stages_divisibility(self):
-        for k in [3 * 64 + 64 * 4 * 2 + 64 * i for i in range(1, 4)]:
-            self._run_problem(2 * 256, 16, k, 64, 256)
+    # Symmetric quantize
+    w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_very_few_stages(self):
-        for k in [64, 128, 192]:
-            self._run_problem(3 * 256, 16, k, 64, 256)
+    # Obtains reference output
+    output_ref = torch.matmul(a_input, w_24_ref)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_llama_shapes(self):
-        MODELS = {
-            " 7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
-            "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
-            "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
-            "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
-        }
-
-        try:
-            for _, layers in MODELS.items():
-                for layer in layers:
-                    for thread_k, thread_m in [(128, 128)]:
-                        for batch in [16]:
-                            print(layer[1], batch, layer[0])
-                            self._run_problem(layer[1], batch, layer[0], thread_k, thread_m)
-        # If someone runs this on a GPU with less than 24GB of memory, it will run out of memory
-        # but we don't want to fail the test
-        except torch.OutOfMemoryError:
-            pass
+    # Packs to marlin 2:4
+    marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
+    workspace_24 = marlin_24_workspace(size_n)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-    def test_groups(self):
-        for m in [16]:
-            for groupsize in [128]:
-                for n, k in [(256, 512), (256, 1024), (256 * 128, 1024)]:
-                    for thread_shape in [(128, 128), (64, 256)]:
-                        self._run_problem(n, m, k, *thread_shape, groupsize)
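+    # The trailing three marlin_24_gemm arguments are the GEMM sizes: size_m, size_n, size_k.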
+    fn_inputs = (
+        a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
+        num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
+    )
+    output = torchao.ops.marlin_24_gemm(*fn_inputs)
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 0.04
+
+    # Performs opcheck
+    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
+    opcheck(
+        torch.ops.torchao.marlin_24_gemm,
+        fn_inputs,
+        test_utils=test_utils,
+    )
 
 
 if __name__ == "__main__":