
Add torchao mps ops #1415


Merged: 1 commit merged into main on Dec 13, 2024

Conversation

@manuelcandales (Contributor) commented on Dec 10, 2024:

This PR adds the quantization scheme linear:afpwx, which quantizes only the weights, groupwise, with a specified bitwidth and groupsize. It takes the arguments bitwidth (1, 2, 3, 4, 5, 6, or 7) and groupsize (32, 64, 128, or 256).
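For intuition, here is a minimal sketch of groupwise, weight-only affine quantization in plain PyTorch (illustrative only; this is not the Metal kernel code the PR adds, and the function name is hypothetical):

```python
import torch

def quantize_groupwise(w: torch.Tensor, bitwidth: int = 4, groupsize: int = 256):
    # Each group of `groupsize` consecutive weights shares one scale/zero-point,
    # and values are mapped to `bitwidth`-bit unsigned integers.
    qmin, qmax = 0, 2 ** bitwidth - 1
    groups = w.reshape(-1, groupsize)                      # one row per group
    wmin = groups.min(dim=1, keepdim=True).values
    wmax = groups.max(dim=1, keepdim=True).values
    scale = (wmax - wmin).clamp(min=1e-8) / (qmax - qmin)  # per-group scale
    zero_point = qmin - wmin / scale                       # per-group zero point
    q = torch.clamp(torch.round(groups / scale + zero_point), qmin, qmax)
    return q.reshape(w.shape).to(torch.uint8), scale, zero_point

w = torch.randn(8, 512)  # row length must be divisible by groupsize
q, scale, zp = quantize_groupwise(w, bitwidth=4, groupsize=256)
```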

To use linear:afpwx, you must first set up the torchao mps experimental kernels. These will only work on a device with Apple Silicon.

From the torchchat root directory, run:

```
bash torchchat/utils/scripts/build_torchao_ops.sh mps
```

Note that this quantization scheme is currently implemented only for the mps device.

```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time,"
```
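Any supported combination of arguments can be passed the same way; for example, 3-bit weights with groups of 32 (values drawn from the ranges listed above):

```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 3, "groupsize": 32}}' --prompt "Once upon a time,"
```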

pytorch-bot (bot) commented on Dec 10, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1415

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 2 Unrelated Failures

As of commit c164f88 with merge base 4dc2f89:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

### Use

#### linear:fpaxw
The quantization scheme linear:fpaxw quantizes only the weights in a groupwise manner with a specified bitwidth and groupsize.
Contributor commented:

Should we keep the naming convention that torchchat uses ("a" followed by the activation type and "w" followed by the weight type)? This started with a8w4dq, before I added any kernels. In your case this would be something like afpwx?

@manuelcandales (author) replied:

Ok, this makes sense.
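For reference, one way to read that convention (illustrative only, not an official mapping):

```python
# Hypothetical decoding of the scheme-name convention discussed above:
# "a" + activation type, then "w" + weight type, plus optional suffixes.
SCHEME_NAME_EXAMPLES = {
    "a8w4dq": "8-bit activations, 4-bit weights, dynamically quantized",
    "afpwx": "floating-point activations, x-bit weights (x = the bitwidth argument)",
}
```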


#### Eager mode
```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:fpaxw": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time," --num-samples 5
```
@metascroy (Contributor) commented on Dec 13, 2024:

Do these only work in eager mode? If so, say that explicitly in the set-up section.

@manuelcandales (author) replied:

The Metal lowbit kernels run with ExecuTorch as well (the llama runner can use them). However, my aim in this torchchat PR was only to enable eager mode. I plan a follow-up PR to enable them via the torchchat ET path as well, but I prefer to keep it modular.

@manuelcandales (author) added:

I added a sentence to the setup section clarifying that torchchat can currently use them only in eager mode.

Contributor commented:

Can you add a CI test for the MPS kernels to make sure they install and run?

See https://github.com/pytorch/torchchat/blob/main/.github/workflows/pull.yml#L1060 as an example.

@manuelcandales (author) replied:

done!
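For context, a sketch of the commands such a CI job would run, based on the setup and generate steps above (the actual workflow lives in .github/workflows/pull.yml):

```
# Hypothetical outline of the MPS kernel test; see pull.yml for the real job.
bash torchchat/utils/scripts/build_torchao_ops.sh mps    # build/install the kernels
python3 torchchat.py generate stories110M --device mps --dtype float32 \
  --quantize '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}' \
  --prompt "Once upon a time,"                           # smoke-test generation
```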

@manuelcandales force-pushed the torchao-mps branch 3 times, most recently from 270044b to 75804d8, on December 13, 2024 at 18:56.
@manuelcandales merged commit 570aebc into main on Dec 13, 2024 (51 of 53 checks passed).
vmpuri pushed a commit that referenced this pull request Feb 4, 2025
Labels: CLA Signed