
Commit 55b9f6e

testing HQQ [not for land]
Summary: for eval=5 on wikitext: {'word_perplexity,none': 11.49343838017535, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6110947678444059, 'byte_perplexity_stderr,none': ...}; for eval on all: ...

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e1564ea
Pull Request resolved: #155
1 parent 095b222 commit 55b9f6e

File tree: 3 files changed (+77 -2 lines)


generate.py (+10 -1)
@@ -224,7 +224,16 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
 
-    if "int4" in str(checkpoint_path):
+    if "int4-hqq" in str(checkpoint_path):
+        print("Using int4 weight-only HQQ quantization.")
+        from quantize import WeightOnlyInt4HqqQuantHandler
+        path_comps = checkpoint_path.name.split(".")
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
+        quantizer = WeightOnlyInt4HqqQuantHandler(model, groupsize=groupsize)
+        model = quantizer._convert_for_runtime()
+    elif "int4" in str(checkpoint_path):
         print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-3].startswith("g")
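
For reference, this branch dispatches on the quantized checkpoint's filename, which quantize.py constructs as {label}int4-hqq.g{groupsize}.{device}.pth (see the quantize.py diff below). A minimal sketch of that parsing, using a hypothetical path that matches the naming used in run.sh:

from pathlib import Path

# Hypothetical checkpoint produced by `quantize.py --mode int4-hqq` with groupsize 32 on cuda.
checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int4-hqq.g32.cuda.pth")

path_comps = checkpoint_path.name.split(".")  # ['model_int4-hqq', 'g32', 'cuda', 'pth']
assert path_comps[-3].startswith("g")         # group-size component, e.g. 'g32'
groupsize = int(path_comps[-3][1:])           # -> 32
assert path_comps[-2] in "cuda"               # device tag baked into the file must match the runtime device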

quantize.py (+40 -1)
@@ -519,6 +519,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
         )
 
+# TODO a hacky placeholder class
+class WeightOnlyInt4HqqQuantHandler:
+    def __init__(self, mod, groupsize):
+        self.mod = mod
+        self.groupsize = groupsize
+
+    def _create_quantized_state_dict(self):
+        from hqq.core.quantize import Quantizer  # TODO maybe torchao
+
+        for m in self.mod.modules():
+            for name, child in m.named_children():
+                if isinstance(child, torch.nn.Linear):
+                    child.weight = torch.nn.Parameter(
+                        Quantizer.dequantize(
+                            *Quantizer.quantize(
+                                child.weight,
+                                nbits=4,
+                                group_size=self.groupsize,
+                                axis=1,
+                            )
+                        )
+                    )
+
+        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+
+    def _convert_for_runtime(self):
+        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)
 
 def quantize(
     checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
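
Note that the handler above does not keep HQQ-specific packed weights: it fake-quantizes each Linear weight in place (quantize, then immediately dequantize) and defers to the existing WeightOnlyInt4QuantHandler / WeightOnlyInt4GPTQQuantHandler paths for packing and runtime conversion. A minimal standalone sketch of that round trip, assuming only the hqq Quantizer calls already used in the diff (tensor shape and group size here are illustrative):

import torch
from hqq.core.quantize import Quantizer

weight = torch.randn(4096, 4096)

# 4-bit quantization with per-group scales/zeros along the input dimension (axis=1),
# immediately dequantized: the HQQ-optimized values are baked back into a float tensor
# that the existing int4 packing path can consume unchanged.
w_q, meta = Quantizer.quantize(weight, nbits=4, group_size=32, axis=1)
weight_fake_quant = Quantizer.dequantize(w_q, meta)

print((weight - weight_fake_quant).abs().mean())  # error introduced by the 4-bit round trip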
@@ -592,6 +619,18 @@ def quantize(
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+
+    elif mode == 'int4-hqq':
+        print("Quantizing model weights for int4 using HQQ")
+        quant_handler = WeightOnlyInt4HqqQuantHandler(model, groupsize)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+
+        quantized_state_dict = quant_handler._create_quantized_state_dict()
+        dir_name = checkpoint_path.parent
+        base_name = checkpoint_path.name
+        new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
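
The new_base_name above is the filename that generate.py keys off and that run.sh below points eval.py at. A small illustration, with a hypothetical label value of '_' (the label default is not shown in this diff, but '_' is consistent with the model_int4-hqq.g32.cuda.pth name used in run.sh):

label, groupsize, device = "_", 32, "cuda"  # label value assumed for illustration
base_name = "model.pth"
new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
# -> 'model_int4-hqq.g32.cuda.pth'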

@@ -606,7 +645,7 @@ def quantize(
     import argparse
     parser = argparse.ArgumentParser(description='Quantize a model.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
-    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
+    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'int4-hqq'], help='type of quantization to perform')
     parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
     parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')

run.sh (+27, new file)
@@ -0,0 +1,27 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-hqq
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --tasks wikitext
+
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --compile
+python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# broken
+
+# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
