Skip to content

Arm_inductor_quantizer for Pt2e quantization #2139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

choudhary-devang
Copy link

@choudhary-devang choudhary-devang commented Apr 28, 2025

Title: Enable PyTorch 2 Export Quantization path for ARM CPUs.

Description:

  • This PR extends the PyTorch 2 Export Quantization (PT2E Quantization) workflow—originally available only on x86 CPUs—to support ARM platforms. PT2E Quantization is an automated, full-graph quantization solution in PyTorch that improves on Eager Mode Quantization by adding support for functionals and automating the overall process. It is part of the torch.ao module and fully supports quantization when using the compile mode.

Key Changes:

  • Introduces ARM-specific support by leveraging oneDNN kernels for matmuls and convolution.

  • Integrates pre-defined configuration selection to automatically choose the best quantization settings based on the selected quantization method.

Provides customization options via two flags:

  • qat_state: Indicates whether to use Quantization Aware Training (if set to True) or Post Training Quantization (if set to False). The default remains False.
  • dynamic_state: Selects between dynamic quantization (if True) and static quantization (if False). The default is also set to False.
    Screenshot 2025-01-22 105543

These options allow users to tailor the quantization process for their specific workload requirements (e.g., using QAT for fine-tuning or PTQ for calibration-based quantization).

Testing and Validation:

The new ARM flow has been thoroughly tested across a range of models with all combinations:
NLP: Models such as BERT and T5.
Vision: Models like ResNet and ViT.
Custom Models: user defined models with various operators.

example script:

import torch
from transformers import BertModel
import copy
import time
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as aiq
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import ArmInductorQuantizer
import torch.profiler
import torch._inductor.config as config
# Enable C++ wrapper for Inductor
config.cpp_wrapper = True
config.freezing=True

model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)

# Set the model to eval mode
model = model.eval()

# Create the data, using dummy data here as an example
traced_bs = 32
seq_length = 128
x = torch.randint(0, 10000, (traced_bs, seq_length))
attention_mask = torch.ones((traced_bs, seq_length))
example_inputs = (x, attention_mask)

# Capture the FX Graph to be quantized
with torch.no_grad():
    exported_model = torch.export.export_for_training(model, example_inputs).module()

    # Set up the quantizer and prepare the model for post-training quantization
    quantizer = ArmInductorQuantizer()
    quantizer.set_global(aiq.get_default_arm_inductor_quantization_config(is_dynamic=True))
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Run the prepared model to apply the quantization
    prepared_model(*example_inputs)

    # Convert the model to the quantized version
    converted_model = convert_pt2e(prepared_model)
    optimized_model = torch.compile(converted_model)
    st = time.time()
    optimized_model(*example_inputs)
    et = time.time()
    print(f"Average time required for inference = {et-st}\n")


cc: @jerryzh168, @fadara01, @Xia-Weiwen

Copy link

pytorch-bot bot commented Apr 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2139

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

Hi @choudhary-devang!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@choudhary-devang
Copy link
Author

This pr is moved from torch to torchao due to the Pt2e migration to torchao

old pr link
(pytorch/pytorch#146690)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants