Skip to content

Commit eb9bc71

Browse files
authored
feat: add rotary_dim argument to rope APIs for partial apply rope (#599)
This PR implements the final piece of #530 , so that we can partially apply rotary embedding to first head dimensions instead of entire head dimensions. We also add a simple benchmark for RoPE, below is the result on H100: ```python batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 23us, throughput: 0.876GB/s batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 26us, throughput: 0.801GB/s batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.735GB/s batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.639GB/s batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 31us, throughput: 672.889GB/s batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 32us, throughput: 662.972GB/s --- batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 14.559GB/s batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 14.435GB/s batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 37us, throughput: 1339.450GB/s batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 37us, throughput: 1340.399GB/s batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 148us, throughput: 2696.563GB/s batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 148us, throughput: 2689.104GB/s --- batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 74.186GB/s batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 74.452GB/s batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 110us, throughput: 2350.830GB/s batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 110us, throughput: 2359.814GB/s batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 717us, throughput: 2895.389GB/s batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 718us, throughput: 2891.385GB/s --- batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.449GB/s batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.646GB/s batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 130us, throughput: 2576.101GB/s batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 130us, throughput: 2582.447GB/s batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.154GB/s batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 925us, throughput: 2903.484GB/s ```
1 parent 2043ca2 commit eb9bc71

File tree

8 files changed

+493
-197
lines changed

8 files changed

+493
-197
lines changed

benchmarks/bench_rope.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import argparse
2+
from typing import cast
3+
4+
import torch
5+
from triton.testing import do_bench
6+
7+
import flashinfer
8+
9+
10+
def generate_cos_sin_f32_cache(max_seq_len, head_dim, theta=1e4):
11+
position = torch.arange(max_seq_len).float().unsqueeze(1)
12+
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
13+
freqs = torch.cat([freqs, freqs], dim=-1).contiguous()
14+
args = position * freqs
15+
sin_cache = torch.sin(args)
16+
cos_cache = torch.cos(args)
17+
return cos_cache, sin_cache
18+
19+
20+
@torch.inference_mode()
21+
def main():
22+
parser = argparse.ArgumentParser()
23+
parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 128])
24+
parser.add_argument("--append-len", nargs="+", type=int, default=[1, 128, 1024])
25+
parser.add_argument("--num-qo-heads", type=int, default=32)
26+
parser.add_argument("--num-kv-heads", type=int, default=8)
27+
parser.add_argument("--head-dim", type=int, default=128)
28+
args = parser.parse_args()
29+
30+
eps = 1e-6
31+
dtype = torch.float16
32+
num_qo_heads = args.num_qo_heads
33+
num_kv_heads = args.num_kv_heads
34+
head_dim = args.head_dim
35+
36+
# Loop over each combination of batch_size, hidden_size, and dtype
37+
for batch_size in args.batch_sizes:
38+
for append_len in args.append_len:
39+
for use_cos_sin_cache in [False, True]:
40+
# Define tensors with the correct dtype
41+
42+
q = torch.randn(
43+
(batch_size * append_len, num_qo_heads, args.head_dim),
44+
dtype=dtype,
45+
device="cuda",
46+
)
47+
k = torch.randn(
48+
(batch_size * append_len, num_kv_heads, args.head_dim),
49+
dtype=dtype,
50+
device="cuda",
51+
)
52+
pos_ids = torch.repeat_interleave(
53+
torch.arange(append_len, dtype=torch.int32, device=q.device),
54+
batch_size,
55+
)
56+
cos_cache, sin_cache = generate_cos_sin_f32_cache(4096, head_dim)
57+
cos_cache = cos_cache.to(q.device)
58+
sin_cache = sin_cache.to(q.device)
59+
60+
@torch.cuda.nvtx.range(
61+
f"apply_rope batch_size={batch_size}, append_len={append_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}"
62+
)
63+
def fn() -> None:
64+
if use_cos_sin_cache:
65+
flashinfer.apply_rope_with_cos_sin_cache(
66+
q, k, cos_cache, sin_cache, pos_ids
67+
)
68+
else:
69+
flashinfer.apply_rope_pos_ids(q, k, pos_ids)
70+
71+
# Run benchmarking
72+
latency_ms = cast(float, do_bench(fn))
73+
throughput = (
74+
q.numel() * q.element_size() * 2 + k.numel() * k.element_size() * 2
75+
) / (latency_ms * 1e-3)
76+
print(
77+
f"batch_size: {batch_size:3},",
78+
f"append_len: {append_len:5},",
79+
f"num_qo_heads: {num_qo_heads:5},",
80+
f"num_kv_heads: {num_kv_heads:5},",
81+
f"head_dim: {head_dim:5},",
82+
f"use_cos_sin_cache: {use_cos_sin_cache},",
83+
f"latency: {latency_ms*1e3:2.0f}us,",
84+
f"throughput: {throughput*1e-9:7.3f}GB/s",
85+
)
86+
87+
print("---")
88+
89+
torch.cuda.profiler.stop()
90+
91+
92+
if __name__ == "__main__":
93+
main()

include/flashinfer/cutlass_utils.cuh

-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
#ifndef FLASHINFER_CUTLASS_UTILS_CUH_
1717
#define FLASHINFER_CUTLASS_UTILS_CUH_
1818

19-
#include <cuda_runtime.h>
20-
#include <cutlass/cutlass.h>
21-
2219
#include "cute/tensor.hpp"
2320
#include "cutlass/cutlass.h"
2421
#include "cutlass/epilogue/collective/collective_builder.hpp"

0 commit comments

Comments
 (0)