Commit 8f5f349

feat: warmup for jit kernel tests (#629)
Unit tests are currently slow in FlashInfer's JIT mode because each kernel is compiled the first time it is used: compilation blocks the test and does not build multiple ops in parallel. This PR adds a warmup pre-hook to the kernel unit tests so that, in JIT mode, all required kernels are compiled in parallel before the tests run, which greatly speeds up the test suite. It also fixes several issues from #628:

1. Use the thread-safe `os.makedirs(..., exist_ok=True)` instead of relying on an `os.path.exists` check before creating directories.
2. Change the signature of `parallel_load_modules` to take a list of `(jit_module_creation_func, args)` tuples instead of lambdas, because a lambda captures variables by reference rather than by value, which can cause unexpected errors.

1 parent: 92ac440

16 files changed: +493 −48 lines

python/flashinfer/jit/activation.py (+1 −2)

```diff
@@ -62,8 +62,7 @@ def get_act_and_mul_cu_str(act_func_name: str, act_func_def: str) -> str:
 
 def gen_act_and_mul_module(act_func_name: str, act_func_def: str) -> None:
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
+    os.makedirs(gen_directory, exist_ok=True)
     sources = [gen_directory / f"{act_func_name}_and_mul.cu"]
     write_if_different(
         sources[0],
```
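
The pattern removed here is a classic check-then-create race: once kernels are generated from multiple threads (as the parallel warmup below does), another thread can create the directory between the `os.path.exists` check and the `os.makedirs` call, which then raises `FileExistsError`. A minimal sketch of the two variants, using a hypothetical scratch directory rather than FlashInfer's real paths:

```python
import os
import tempfile
import threading

# Hypothetical target directory, standing in for FLASHINFER_GEN_SRC_DIR.
gen_dir = os.path.join(tempfile.gettempdir(), "flashinfer_demo_gen_src")

def racy_mkdir():
    # Check-then-create: another thread may create the directory between
    # the exists() check and makedirs(), raising FileExistsError.
    if not os.path.exists(gen_dir):
        os.makedirs(gen_dir)

def safe_mkdir():
    # exist_ok=True makes the call idempotent, so concurrent callers are fine.
    os.makedirs(gen_dir, exist_ok=True)

threads = [threading.Thread(target=safe_mkdir) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```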

python/flashinfer/jit/attention.py (−12)

```diff
@@ -155,8 +155,6 @@ def get_batch_decode_uri(
 
 def gen_batch_decode_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_batch_decode_uri(*args)
     sources = get_batch_decode_sources(*args)
     source_paths = []
@@ -214,8 +212,6 @@ def get_batch_decode_mla_uri(
 
 def gen_batch_decode_mla_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_batch_decode_mla_uri(*args)
     sources = get_batch_decode_mla_sources(*args)
     source_paths = []
@@ -275,8 +271,6 @@ def get_single_prefill_uri(
 
 def gen_single_prefill_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_single_prefill_uri(*args)
     sources = get_single_prefill_sources(*args)
     source_paths = []
@@ -341,8 +335,6 @@ def get_batch_prefill_uri(
 
 def gen_batch_prefill_module(*args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     uri = get_batch_prefill_uri(*args)
     sources = get_batch_prefill_sources(*args)
     source_paths = []
@@ -518,8 +510,6 @@ def get_customize_single_prefill_sources(
 
 def gen_customize_single_decode_module(module_name, *args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     sources = get_customize_single_decode_sources(*args)
     source_paths = []
     for suffix, source in zip(single_decode_suffix, sources):
@@ -532,8 +522,6 @@ def gen_customize_single_decode_module(module_name, *args):
 
 def gen_customize_single_prefill_module(module_name, *args):
     gen_directory = FLASHINFER_GEN_SRC_DIR
-    if not os.path.exists(gen_directory):
-        os.makedirs(gen_directory)
     sources = get_customize_single_prefill_sources(*args)
     source_paths = []
     for suffix, source in zip(single_prefill_suffix, sources):
```

python/flashinfer/jit/core.py (+3 −4)

```diff
@@ -14,8 +14,8 @@
 from .env import FLASHINFER_JIT_DIR as FLASHINFER_JIT_DIR
 from .env import FLASHINFER_WORKSPACE_DIR as FLASHINFER_WORKSPACE_DIR
 
-if not os.path.exists(FLASHINFER_WORKSPACE_DIR):
-    os.makedirs(FLASHINFER_WORKSPACE_DIR)
+os.makedirs(FLASHINFER_WORKSPACE_DIR, exist_ok=True)
+os.makedirs(FLASHINFER_CSRC_DIR, exist_ok=True)
 
 
 class FlashInferJITLogger(logging.Logger):
@@ -99,8 +99,7 @@ def load_cuda_ops(
     logger.info(f"Loading JIT ops: {name}")
     check_cuda_arch()
     build_directory = FLASHINFER_JIT_DIR / name
-    if not os.path.exists(build_directory):
-        os.makedirs(build_directory, exist_ok=True)
+    os.makedirs(build_directory, exist_ok=True)
     if extra_include_paths is None:
         extra_include_paths = [
             FLASHINFER_INCLUDE_DIR,
```

python/flashinfer/jit/utils.py (+6 −6)

```diff
@@ -16,7 +16,7 @@
 
 import pathlib
 import threading
-from typing import Callable, List
+from typing import Any, Callable, List, Tuple
 
 import torch
 
@@ -35,19 +35,19 @@ def write_if_different(path: pathlib.Path, content: str) -> None:
 
 
 def parallel_load_modules(
-    load_module_funcs: List[Callable],
+    load_module_func_args: List[Tuple[Callable, List[Any]]],
 ):
     threads = []
     exceptions = []
 
-    def wrapper(func):
+    def wrapper(func, args):
         try:
-            func()
+            func(*args)
         except Exception as e:
             exceptions.append((func, e))
 
-    for func in load_module_funcs:
-        thread = threading.Thread(target=wrapper, args=(func,))
+    for func, args in load_module_func_args:
+        thread = threading.Thread(target=wrapper, args=(func, args))
         thread.start()
         threads.append(thread)
 
```
tests/jit_utils.py (new file, +149)

```python
"""
Copyright (c) 2023 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import itertools

import torch

import flashinfer


def jit_decode_attention_func_args(
    q_dtypes,
    kv_dtypes,
    head_dims,
    pos_encoding_modes,
    use_sliding_window_options,
    use_logits_soft_cap_options,
):
    load_module_func_args = []

    for (
        q_dtype,
        kv_dtype,
        head_dim,
        pos_encoding_mode,
        use_sliding_window,
        use_logits_soft_cap,
    ) in itertools.product(
        q_dtypes,
        kv_dtypes,
        head_dims,
        pos_encoding_modes,
        use_sliding_window_options,
        use_logits_soft_cap_options,
    ):
        load_module_func_args.append(
            (
                flashinfer.decode.get_single_decode_module,
                (
                    q_dtype,
                    kv_dtype,
                    q_dtype,
                    head_dim,
                    pos_encoding_mode,
                    use_sliding_window,
                    use_logits_soft_cap,
                ),
            )
        )
        load_module_func_args.append(
            (
                flashinfer.decode.get_batch_decode_module,
                (
                    q_dtype,
                    kv_dtype,
                    q_dtype,
                    torch.int32,
                    head_dim,
                    pos_encoding_mode,
                    use_sliding_window,
                    use_logits_soft_cap,
                ),
            )
        )

    return load_module_func_args


def jit_prefill_attention_func_args(
    q_dtypes,
    kv_dtypes,
    head_dims,
    pos_encoding_modes,
    use_sliding_window_options,
    use_logits_soft_cap_options,
    allow_fp16_qk_reduction_options,
):
    load_module_func_args = []

    for (
        q_dtype,
        kv_dtype,
        head_dim,
        pos_encoding_mode,
        use_sliding_window,
        use_logits_soft_cap,
        allow_fp16_qk_reduction,
    ) in itertools.product(
        q_dtypes,
        kv_dtypes,
        head_dims,
        pos_encoding_modes,
        use_sliding_window_options,
        use_logits_soft_cap_options,
        allow_fp16_qk_reduction_options,
    ):
        load_module_func_args.append(
            (
                flashinfer.prefill.gen_single_prefill_module,
                (
                    q_dtype,
                    kv_dtype,
                    q_dtype,
                    head_dim,
                    pos_encoding_mode,
                    use_sliding_window,
                    use_logits_soft_cap,
                    allow_fp16_qk_reduction,
                ),
            )
        )
        load_module_func_args.append(
            (
                flashinfer.prefill.gen_batch_prefill_module,
                (
                    q_dtype,
                    kv_dtype,
                    q_dtype,
                    torch.int32,
                    head_dim,
                    pos_encoding_mode,
                    use_sliding_window,
                    use_logits_soft_cap,
                    allow_fp16_qk_reduction,
                ),
            )
        )

    load_module_func_args.append(
        (
            flashinfer.quantization.get_quantization_module,
            [],
        )  # required for attention with custom mask
    )

    return load_module_func_args
```
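
These helpers expand the cartesian product of their option lists into one `(module_builder, args)` entry per kernel variant; `parallel_load_modules` then calls `builder(*args)` for each entry on its own thread. As a rough illustration (a hedged sketch that assumes `flashinfer` and this `jit_utils` module are importable, not repo code), the decode configuration used in `tests/test_alibi.py` below expands to eight entries:

```python
import torch
from jit_utils import jit_decode_attention_func_args

decode_jobs = jit_decode_attention_func_args(
    [torch.float16],  # q_dtypes
    [torch.float16],  # kv_dtypes
    [128, 256],       # head_dims
    [0, 2],           # pos_encoding_modes
    [False],          # use_sliding_windows
    [False],          # use_logits_soft_caps
)
# 2 head_dims x 2 pos_encoding_modes = 4 combinations, each contributing
# a single-decode entry and a batch-decode entry.
assert len(decode_jobs) == 8

builder, args = decode_jobs[0]
# builder is flashinfer.decode.get_single_decode_module; args is the positional
# tuple that parallel_load_modules will unpack as builder(*args).
```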

tests/test_alibi.py (+32)

```diff
@@ -18,10 +18,42 @@
 import pytest
 import torch
 from alibi_reference import alibi_attention
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("seq_len", [1, 9, 81, 729])
 @pytest.mark.parametrize("num_heads", [4, 8, 32])
 @pytest.mark.parametrize("head_dim", [128, 256])
```

tests/test_batch_decode_kernels.py (+32)

```diff
@@ -16,10 +16,42 @@
 
 import pytest
 import torch
+from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_decode_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+            )
+            + jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("batch_size", [12, 17])
 @pytest.mark.parametrize("kv_len", [54, 97, 512])
 @pytest.mark.parametrize("page_size", [1, 8, 16])
```

tests/test_batch_prefill_kernels.py (+24)

```diff
@@ -16,10 +16,34 @@
 
 import pytest
 import torch
+from jit_utils import jit_prefill_attention_func_args
 
 import flashinfer
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        return
+    try:
+        flashinfer.jit.parallel_load_modules(
+            jit_prefill_attention_func_args(
+                [torch.float16],  # q_dtypes
+                [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2],  # kv_dtypes
+                [128, 256],  # head_dims
+                [0, 1, 2],  # pos_encoding_modes
+                [False],  # use_sliding_windows
+                [False, True],  # use_logits_soft_caps
+                [False],  # allow_fp16_qk_reductions
+            )
+        )
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 @pytest.mark.parametrize("batch_size", [12, 17])
 @pytest.mark.parametrize("kv_len", [54, 97])
 @pytest.mark.parametrize("qo_len", [37, 17])
```
