Error message

```
File "/home/weiwen/torchao/torchao/quantization/pt2e/pt2e/_affine_quantization.py", line 739, in forward
    self.min_val.shape == min_val.shape
AssertionError: Can't update existing min_val - shape mismatch, self.min_val:torch.Size([]) != min_val:torch.Size([1, 1])
```
How to reproduce
```python
import copy

import torch
import torchao
import torchvision
from torchao.quantization.pt2e.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
)
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torchao.quantization.pt2e import move_exported_model_to_eval
from torch.ao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86


def get_quant_config():
    from torchao.quantization.pt2e.pt2e._affine_quantization import (
        AffineQuantizedMinMaxObserver,
        AffineQuantizedMovingAverageMinMaxObserver,
    )
    from torchao.quantization.pt2e.observer import MappingType, PerAxis, PerTensor
    from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec
    from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import QuantizationConfig

    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=AffineQuantizedMovingAverageMinMaxObserver.with_args(
            target_dtype=torch.uint8,
            mapping_type=MappingType.ASYMMETRIC,
            granularity=PerTensor(),
            averaging_constant=1,
        ),
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,  # 0 corresponding to oc
        is_dynamic=False,
        observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
            target_dtype=torch.int8,
            mapping_type=MappingType.SYMMETRIC,
            granularity=PerAxis(0),
        ),
    )
    bias_quantization_spec = None
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
    )
    return quantization_config


def pt2e_ptq(m, example_inputs):
    m = m.eval()
    exported_model = torch.export.export_for_training(m, example_inputs, strict=True).module()
    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_quant_config())
    prepared_model = prepare_pt2e(exported_model, quantizer)
    print("calibration")
    _ = prepared_model(*example_inputs)
    converted_model = convert_pt2e(prepared_model)
    move_exported_model_to_eval(converted_model)
    print("[info] converted_model =\n", converted_model)


if __name__ == "__main__":
    data = torch.randn(16, 3, 224, 224)
    inputs = (data,)
    model_fp = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    pt2e_ptq(copy.deepcopy(model_fp), inputs)
```
Investigation findings

The cause is that the per-tensor observer receives an input whose dims are of size 1, which trips the predicate computing `reduction_dims` here:

ao/torchao/quantization/pt2e/pt2e/_affine_quantization.py
Line 136 in 2266451

reduction_dims

And the same issue occurs here:

ao/torchao/quantization/quant_primitives.py
Line 277 in 2266451
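The resulting mismatch can be illustrated with a minimal stand-alone sketch in plain PyTorch, independent of torchao's observer internals (the `unreduced_min` variable below is a hypothetical stand-in for the buggy partially-reduced result):

```python
import torch

# Input where every dim has size 1, as in the failing case.
x = torch.randn(1, 1)

# A per-tensor min over all dims collapses to a scalar: torch.Size([])
scalar_min = torch.amin(x)
assert scalar_min.shape == torch.Size([])

# If size-1 dims are excluded from the reduction dims, the reduction is
# effectively a no-op and the running min keeps shape torch.Size([1, 1]).
unreduced_min = x  # hypothetical stand-in for the buggy result
assert unreduced_min.shape == torch.Size([1, 1])

# This is the comparison that raises in the observer:
# self.min_val.shape (torch.Size([])) != min_val.shape (torch.Size([1, 1]))
assert scalar_min.shape != unreduced_min.shape
```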
jerryzh168 commented: thanks for reporting the issue, I'll take a look