Skip to content

[FEAT] Add custom CUDA tinygemm unpacker #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import torch
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
Expand All @@ -7,6 +8,8 @@
from parameterized import parameterized
import pytest

import torchao.quantization

try:
import torchao.ops
except RuntimeError:
Expand Down Expand Up @@ -55,6 +58,141 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
relative_error = error / results_fp16.abs()
assert relative_error.mean() < 1e-2

## Tests for `unpack_int4_packed`
kTileSizeN = 8
kTileSizeK = 16

SHAPES = [
(4096, 4096),
# Llama 2 GEMM shapes
(4096, 11008),
(11008, 4096),
# Llama 3 GEMM shapes
(4096, 14336),
(14336, 4096),
]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str)
def test_int4_unpack_correctness(shape, innerKTiles):
N, K = shape
assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles)
unpacked = torchao.ops.unpack_int4_to_int(packed_w, innerKTiles)
assert torch.allclose(t, unpacked)

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK , ids=str)
def test_int4_unpack_op(shape, innerKTiles):
test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
# "test_aot_dispatch_dynamic",
]
t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles)

opcheck(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which pytorch version are you using? it seems this opcheck is moved to torch.library.opcheck: https://github.com/pytorch/pytorch/pull/124496/files

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch 2.5.0.dev20240624+cu121

torch.ops.torchao.unpack_int4_to_int,
(packed_w, innerKTiles),
test_utils=test_utils,
)

def dequant_ref(q, scales, zeros, group_size, dtype=torch.bfloat16):
n, k = q.shape
assert q.dtype == torch.int

n_groups = k // group_size
assert scales.shape[0] == n and scales.shape[1] == n_groups
assert scales.shape == zeros.shape

q_bf16 = q.to(dtype=dtype)
q_bf16 = q_bf16.reshape(-1, group_size)
dq = (q_bf16 - zeros.reshape(-1, 1)) * scales.reshape(-1, 1)
return dq.reshape(n, k)

@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_int4_correctness(shape, innerKTiles, group_size):
n, k = shape

# tinygemm params
nTileSize = 8
kTileSize = 16
nTiles = n // nTileSize
kTiles = k // (innerKTiles * kTileSize)
numThreads = 32

device = "cuda"
q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
packed_w = torch._convert_weight_to_int4pack(q, innerKTiles)
# tinygemm params
assert packed_w.shape == torch.Size([nTiles, kTiles, numThreads, innerKTiles // 2])

# scales and zeros init
q_groups = k // group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
zeros = torch.randn_like(scales)

scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros)
assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])
scales_unpacked, zeros_unpacked = torchao.quantization.utils.unpack_tinygemm_scales_and_zeros(scales_and_zeros)
assert torch.allclose(scales_unpacked.reshape(scales.shape), scales)
assert torch.allclose(zeros_unpacked.reshape(zeros.shape), zeros)

dq_ref = dequant_ref(q, scales, zeros, group_size)
dq = torchao.ops.dequantize_int4(packed_w, scales_and_zeros, group_size, innerKTiles)
assert torch.allclose(dq, dq_ref, atol=1e-4, rtol=1e-4)

# TODO: Figure out why this fails
# This is how torchao.dtypes.affine_quantized_tensor recovers the original tensor
# https://github.com/pytorch/ao/blob/9dc2c118f59ad4135a8c39166c4ceebda73c62a9/torchao/dtypes/affine_quantized_tensor.py#L505
# a_eye = torch.eye(k, device=device, dtype=torch.bfloat16)
# dq_check = torch.ops.aten._weight_int4pack_mm(
# a_eye,
# packed_w,
# group_size,
# scales_and_zeros,
# ).t()
# assert torch.allclose(dq, dq_check, atol=1e-4, rtol=1e-4)

@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_int4_op(shape, innerKTiles, group_size):
n, k = shape

device = "cuda"
q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
packed_w = torch._convert_weight_to_int4pack(q, innerKTiles)
print(packed_w.shape)
q_groups = k // group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
zeros = torch.randn_like(scales)
scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros)

test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
# "test_aot_dispatch_dynamic",
]
opcheck(
torch.ops.torchao.dequantize_int4,
(packed_w, scales_and_zeros, group_size, innerKTiles),
test_utils=test_utils,
)

if __name__ == "__main__":
unittest.main()
Loading
Loading