Add torchao mps ops #1415
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1415
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 2 Unrelated Failures as of commit c164f88 with merge base 4dc2f89 (broken trunk: the following jobs failed but were present on the merge base). 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from e0c7ced to 0f1825c
Force-pushed from 0f1825c to 1cf4a7f
docs/quantization.md (Outdated)

### Use

#### linear:fpaxw
The quantization scheme linear:fpaxw quantizes only the weights in a groupwise manner with a specified bitwidth and groupsize.
Should we keep the naming convention that torchchat uses ("a" followed by the activation type and "w" followed by the weight type)? This started with a8w4dq before I added any kernels. In your case this would be something like afpwx?
Ok, this makes sense.
docs/quantization.md (Outdated)

#### Eager mode
```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:fpaxw": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time," --num-samples 5
```
Do these only work with eager? If so, explicitly say that in the set-up section?
The metal lowbit kernels run with ExecuTorch as well (the llama runner can use them). However, my aim in this torchchat PR was only to enable eager mode. I plan to have a follow-up PR to enable them via the torchchat ET path as well, but I prefer to keep it modular.
I added a sentence in the setup section clarifying that torchchat can currently only use them in eager mode.
Can you add a CI test for the MPS kernels to make sure they install and run?
See https://github.com/pytorch/torchchat/blob/main/.github/workflows/pull.yml#L1060 as an example.
done!
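For context, here is a minimal sketch of what such a macOS smoke test might run. The install helper path and the download step below are assumptions for illustration only; the actual job added in this PR lives in .github/workflows/pull.yml, and the real kernel install step is the one described in the setup section of docs/quantization.md.

```
# Sketch of a CI smoke test for the torchao MPS lowbit kernels (requires Apple Silicon).
set -eux

./install/install_requirements.sh            # assumed torchchat install helper; adjust to the repo's actual script
python3 torchchat.py download stories110M    # small test model used in the docs example
python3 torchchat.py generate stories110M \
  --device mps --dtype float32 \
  --quantize '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}' \
  --prompt "Once upon a time," --num-samples 5
```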
Force-pushed from 270044b to 75804d8
Force-pushed from 75804d8 to c164f88
This PR adds the quantization scheme linear:afpwx. It quantizes only the weights in a groupwise manner with a specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize (32, 64, 128, 256).
To use linear:afpwx, you must first set up the torchao mps experimental kernels. These will only work on a device with Apple Silicon.
From the torchchat root directory, run
Note that this quantization scheme is currently implemented only for the mps device.
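As an illustrative usage sketch (following the docs example earlier in this thread, with the renamed scheme), any of the supported bitwidth/groupsize combinations can be passed through the --quantize JSON config, e.g.:

```
# Eager-mode generation on mps with 3-bit weights and groupsize 32,
# both within the supported ranges listed above.
python3 torchchat.py generate stories110M \
  --device mps --dtype float32 \
  --quantize '{"linear:afpwx": {"bitwidth": 3, "groupsize": 32}}' \
  --prompt "Once upon a time," --num-samples 5
```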