
Commit f6e0010

feat: torch custom_op support: norm (#552)
Add torch custom_op (a.k.a. torch library / torch.compile) support for `norm.py`. It should be a no-op for PyTorch < 2.4.

Testing is done through `torch.compile`, since the custom_op registrations are expected to keep our kernels isolated as opaque calls during torch.compile. To avoid changing the tests themselves, I added some conftest magic that replaces the kernels with `torch.compile`-ed versions. For example, to run with/without torch.compile:

```bash
# With torch.compile
FLASHINFER_TEST_TORCH_COMPILE=1 pytest -svx tests/test_norm.py

# Without torch.compile
pytest -svx tests/test_norm.py
```

If this PR looks good, I'll add it to more kernels.
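As a rough sketch of how this shows up from the user side (assuming torch >= 2.4, a CUDA device, and a built flashinfer; the shapes and dtypes below are arbitrary), the public wrapper can be compiled directly and should match eager execution, with the kernel kept as a single opaque op:

```python
import torch
import flashinfer

# Arbitrary example inputs (hypothetical shapes/dtypes).
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")

# The public wrapper keeps its original signature; under torch.compile the
# registered custom op should appear as one opaque call instead of being traced into.
y_eager = flashinfer.norm.rmsnorm(x, w)
y_compiled = torch.compile(flashinfer.norm.rmsnorm)(x, w)
torch.testing.assert_close(y_eager, y_compiled)
```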
1 parent 47583b3 commit f6e0010

File tree

3 files changed (+143, -10 lines)

python/flashinfer/norm.py

+57 -8
```diff
@@ -14,9 +14,12 @@
 limitations under the License.
 """
 
+from typing import Optional
+
 import torch
 
-from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .utils import register_custom_op, register_fake_op
 
 _norm_module = None
 
@@ -43,7 +46,7 @@ def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    out: torch.Tensor = None,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""Root mean square normalization.
 
@@ -65,13 +68,28 @@ def rmsnorm(
     """
     if out is None:
         out = torch.empty_like(input)
-    get_norm_module().rmsnorm(out, input, weight, eps)
+    _rmsnorm(out, input, weight, eps)
     return out
 
 
+@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
+def _rmsnorm(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    get_norm_module().rmsnorm(out, input, weight, eps)
+
+
+@register_fake_op("flashinfer::rmsnorm")
+def _rmsnorm_fake(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    pass
+
+
+@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-):
+) -> None:
     r"""Fused add root mean square normalization.
 
     Parameters
@@ -88,12 +106,19 @@ def fused_add_rmsnorm(
     get_norm_module().fused_add_rmsnorm(input, residual, weight, eps)
 
 
+@register_fake_op("flashinfer::fused_add_rmsnorm")
+def _fused_add_rmsnorm_fake(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    pass
+
+
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    out: torch.Tensor = None,
-):
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     r"""Gemma Root mean square normalization.
 
     Parameters
@@ -114,13 +139,30 @@ def gemma_rmsnorm(
     """
     if out is None:
         out = torch.empty_like(input)
-    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+    _gemma_rmsnorm(out, input, weight, eps)
     return out
 
 
+@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
+def _gemma_rmsnorm(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+
+
+@register_fake_op("flashinfer::gemma_rmsnorm")
+def _gemma_rmsnorm_fake(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    pass
+
+
+@register_custom_op(
+    "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
+)
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-):
+) -> None:
     r"""Gemma Fused add root mean square normalization.
 
     Parameters
@@ -135,3 +177,10 @@ def gemma_fused_add_rmsnorm(
         Epsilon for numerical stability.
     """
     get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps)
+
+
+@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
+def _gemma_fused_add_rmsnorm_fake(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    pass
```
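For reference, the registrations above also make the mutating kernels reachable through `torch.ops` on torch >= 2.4, while the public wrappers (`rmsnorm`, `gemma_rmsnorm`, ...) keep allocating `out` as before. A minimal, hedged sketch of calling the registered op directly (arbitrary shapes; assumes a CUDA build of flashinfer):

```python
import torch
import flashinfer  # importing flashinfer (and thus flashinfer.norm) should register "flashinfer::rmsnorm" on torch >= 2.4

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
w = torch.ones(1024, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# The op mutates `out` in place and returns None, matching mutates_args=("out",).
torch.ops.flashinfer.rmsnorm(out, x, w, 1e-6)
```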

python/flashinfer/utils.py

+30 -2
```diff
@@ -14,10 +14,13 @@
 limitations under the License.
 """
 
-import torch
 import math
 from enum import Enum
-from typing import Optional, Tuple, Union, Dict
+from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
+
+import torch
+from torch.torch_version import TorchVersion
+from torch.torch_version import __version__ as torch_version
 
 
 class PosEncodingMode(Enum):
@@ -197,3 +200,28 @@ def _check_cached_qkv_data_type(
         raise ValueError(
             f"The dtype of k {k.dtype} does not match the kv_data_type {dtype_kv} specified in plan function."
         )
+
+
+def register_custom_op(
+    name: str,
+    fn: Optional[Callable] = None,
+    /,
+    *,
+    mutates_args: Union[str, Iterable[str]],
+    device_types: Optional[Union[str, Sequence[str]]] = None,
+    schema: Optional[str] = None,
+) -> Callable:
+    if TorchVersion(torch_version) < TorchVersion("2.4"):
+        return fn
+    return torch.library.custom_op(
+        name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
+    )
+
+
+def register_fake_op(
+    name: str,
+    fn: Optional[Callable] = None,
+) -> Callable:
+    if TorchVersion(torch_version) < TorchVersion("2.4"):
+        return fn
+    return torch.library.register_fake(name, fn)
```
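These helpers version-gate the registration: on torch >= 2.4 they defer to `torch.library.custom_op` / `torch.library.register_fake`, otherwise they simply return `fn`, so registration degrades to a no-op on older PyTorch. A minimal sketch of the intended decorator pattern, mirroring `norm.py` (the op name `flashinfer::scale_inplace` is hypothetical; assumes torch >= 2.4):

```python
import torch
from flashinfer.utils import register_custom_op, register_fake_op


# Real implementation: mutates `out` in place, so it is listed in mutates_args.
@register_custom_op("flashinfer::scale_inplace", mutates_args=("out",))
def _scale_inplace(out: torch.Tensor, x: torch.Tensor, alpha: float) -> None:
    out.copy_(x * alpha)


# Fake (meta) implementation: lets torch.compile trace shapes without running the kernel.
@register_fake_op("flashinfer::scale_inplace")
def _scale_inplace_fake(out: torch.Tensor, x: torch.Tensor, alpha: float) -> None:
    pass
```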

tests/conftest.py

+56
```diff
@@ -0,0 +1,56 @@
+import os
+import types
+
+import flashinfer
+import pytest
+import torch
+from torch.torch_version import TorchVersion
+from torch.torch_version import __version__ as torch_version
+
+TORCH_COMPILE_FNS = [
+    flashinfer.norm.rmsnorm,
+    flashinfer.norm.fused_add_rmsnorm,
+    flashinfer.norm.gemma_rmsnorm,
+    flashinfer.norm.gemma_fused_add_rmsnorm,
+]
+
+
+def _monkeypatch_add_torch_compile(func):
+    """
+    Replace the given function with its torch.compile version.
+    """
+
+    from torch._library.custom_ops import CustomOpDef
+
+    if type(func) is types.FunctionType:
+        fn = func
+    elif isinstance(func, CustomOpDef):
+        fn = func._init_fn
+    else:
+        raise ValueError(f"Unsupported fn type {type(func)}")
+
+    components = fn.__module__.split(".")
+    assert components[0] == "flashinfer"
+    module = flashinfer
+    for component in components[1:]:
+        module = getattr(module, component)
+
+    setattr(
+        module,
+        fn.__name__,
+        torch.compile(
+            func,
+            fullgraph=True,
+            backend="inductor",
+            mode="max-autotune-no-cudagraphs",
+        ),
+    )
+    print("Applied torch.compile to", f"{fn.__module__}.{fn.__name__}")
+
+
+def pytest_configure(config):
+    if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1":
+        if torch_version < TorchVersion("2.4"):
+            pytest.skip("torch.compile requires torch >= 2.4")
+        for fn in TORCH_COMPILE_FNS:
+            _monkeypatch_add_torch_compile(fn)
```
