high throughput inference #663

Open
msaroufim opened this issue Aug 12, 2024 · 3 comments

Comments

msaroufim (Member) commented Aug 12, 2024

Was chatting with @Chillee about our plans in AO today, and he mentioned we should be focusing on a few concrete problems:

  1. Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.
  2. Demonstrate compelling perf for weight only int8 gemm at a variety of batch sizes.
  3. Demonstrate compelling perf for weight only intX gemm at low batch sizes.
  4. Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.

As a baseline, we could extend gpt-fast to work with bs=n without doing any KV cache management work and measure perf there; a rough sketch of such a batch-size sweep follows below. Copying the feedback as is, open to discussing more and adding more details as time progresses.

EDIT: gpt-fast already has a batched generation branch by Horace https://github.com/pytorch-labs/gpt-fast/tree/batched_generation
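A minimal sketch of the kind of batch-size sweep items 2 and the baseline suggestion imply, assuming torchao's `quantize_` / `int8_weight_only` API and timing a single linear layer rather than full gpt-fast generation (shapes and batch sizes here are illustrative, not from the issue):

```python
# Hypothetical bf16 vs weight-only int8 sweep over batch sizes on one nn.Linear.
# Assumes torchao's quantize_ / int8_weight_only entry points; shapes are illustrative.
import copy
import torch
from torch import nn
from torchao.quantization import quantize_, int8_weight_only

def bench_ms(mod, x, iters=50):
    # Warm up, then time with CUDA events; returns average ms per forward pass.
    for _ in range(5):
        mod(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        mod(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

K, N = 4096, 4096
base = nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
quant = copy.deepcopy(base)
quantize_(quant, int8_weight_only())

for bs in (1, 8, 32, 128, 512):
    x = torch.randn(bs, K, dtype=torch.bfloat16, device="cuda")
    print(f"bs={bs:4d}  bf16={bench_ms(base, x):.3f} ms  int8-wo={bench_ms(quant, x):.3f} ms")
```

The same harness could be pointed at fp8 or intX configs to cover the other items, with end-to-end tokens/sec from gpt-fast's batched_generation branch as the more realistic measurement.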

msaroufim changed the title from "chilli feedback" to "high throughput inference" on Aug 12, 2024
msaroufim (Member, Author) commented:
@HDCharles on the int8 work
@vkuzo on fp8
@vayuda and @jerryzh168 on intx

jeromeku (Contributor) commented:
@msaroufim

Would be interesting to bench against something like QoQ, which implements W4A8KV4 (int8 GEMM) using a nested quantization scheme and neat kernel-level optimizations.

vkuzo (Contributor) commented Aug 13, 2024

Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.

Note that I'm putting up a PR soon for a quick roofline estimator for float8 gemm + training-specific overhead, to see for which M, K, N float8 is faster than bfloat16; it would be easily extendable to inference at a later time.
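As a rough illustration of what such a roofline comparison computes (this is not the PR's code; the peak throughput and bandwidth numbers below are placeholder H100-class values):

```python
# Rough roofline estimate: gemm time ~= max(compute-bound time, memory-bound time).
# Peak TFLOPS / bandwidth are placeholders; a real estimator would parameterize per device
# and account for scaling/cast overhead, which is ignored here.

PEAK_BF16_TFLOPS = 989    # placeholder dense bf16 peak
PEAK_FP8_TFLOPS = 1979    # placeholder dense fp8 peak
PEAK_BW_GBPS = 3350       # placeholder HBM bandwidth

def gemm_time_s(M, K, N, bytes_per_elem, peak_tflops):
    flops = 2 * M * K * N
    mem_bytes = bytes_per_elem * (M * K + K * N + M * N)
    compute_s = flops / (peak_tflops * 1e12)
    memory_s = mem_bytes / (PEAK_BW_GBPS * 1e9)
    return max(compute_s, memory_s)

for M in (1, 16, 256, 4096):
    K = N = 8192
    bf16 = gemm_time_s(M, K, N, 2, PEAK_BF16_TFLOPS)
    fp8 = gemm_time_s(M, K, N, 1, PEAK_FP8_TFLOPS)
    print(f"M={M:5d}  estimated fp8 speedup over bf16 = {bf16 / fp8:.2f}x")
```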

Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.

While this is technically possible, I'm not sure I understand the value; I would be interested to learn more.
