 limitations under the License.
 """
 
+from typing import Optional
+
 import torch
 
-from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops
+from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
+from .utils import register_custom_op, register_fake_op
 
 _norm_module = None
 
@@ -43,7 +46,7 @@ def rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    out: torch.Tensor = None,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""Root mean square normalization.
 
@@ -65,13 +68,28 @@ def rmsnorm(
     """
     if out is None:
         out = torch.empty_like(input)
-    get_norm_module().rmsnorm(out, input, weight, eps)
+    _rmsnorm(out, input, weight, eps)
     return out
 
 
+@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
+def _rmsnorm(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    get_norm_module().rmsnorm(out, input, weight, eps)
+
+
+@register_fake_op("flashinfer::rmsnorm")
+def _rmsnorm_fake(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    pass
+
+
+@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
 def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-):
+) -> None:
     r"""Fused add root mean square normalization.
 
     Parameters
@@ -88,12 +106,19 @@ def fused_add_rmsnorm(
     get_norm_module().fused_add_rmsnorm(input, residual, weight, eps)
 
 
+@register_fake_op("flashinfer::fused_add_rmsnorm")
+def _fused_add_rmsnorm_fake(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    pass
+
+
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    out: torch.Tensor = None,
-):
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     r"""Gemma Root mean square normalization.
 
     Parameters
@@ -114,13 +139,30 @@ def gemma_rmsnorm(
     """
     if out is None:
         out = torch.empty_like(input)
-    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+    _gemma_rmsnorm(out, input, weight, eps)
     return out
 
 
+@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
+def _gemma_rmsnorm(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    get_norm_module().gemma_rmsnorm(out, input, weight, eps)
+
+
+@register_fake_op("flashinfer::gemma_rmsnorm")
+def _gemma_rmsnorm_fake(
+    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, eps: float
+) -> None:
+    pass
+
+
+@register_custom_op(
+    "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
+)
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
-):
+) -> None:
     r"""Gemma Fused add root mean square normalization.
 
     Parameters
@@ -135,3 +177,10 @@ def gemma_fused_add_rmsnorm(
         Epsilon for numerical stability.
     """
    get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps)
+
+
+@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
+def _gemma_fused_add_rmsnorm_fake(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    pass
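
For context on the new decorators: registering each kernel launch as a named custom op with a matching fake (meta) implementation lets these mutating CUDA calls be traced by PyTorch's compiler stack without executing any kernel. The snippet below is a minimal sketch of how such helpers could be layered on torch.library; it is an illustrative assumption, not the actual flashinfer .utils implementation.

    from typing import Callable, Sequence

    import torch

    def register_custom_op(name: str, mutates_args: Sequence[str]) -> Callable:
        # Plausible approach (assumption): defer to torch.library.custom_op
        # (PyTorch >= 2.4) and fall back to a no-op decorator on older versions.
        if hasattr(torch.library, "custom_op"):
            return torch.library.custom_op(name, mutates_args=mutates_args)
        return lambda fn: fn

    def register_fake_op(name: str) -> Callable:
        # Attach a fake/meta kernel so the op can be traced without touching CUDA.
        if hasattr(torch.library, "register_fake"):
            return torch.library.register_fake(name)
        return lambda fn: fn

Callers are unaffected; for example, assuming flashinfer is installed with its CUDA kernels and a GPU is available, rmsnorm can now be captured by torch.compile because the fake op gives the tracer a kernel-free stand-in:

    import torch
    import flashinfer

    @torch.compile
    def norm_layer(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return flashinfer.norm.rmsnorm(x, w, eps=1e-6)

    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    w = torch.ones(4096, device="cuda", dtype=torch.float16)
    y = norm_layer(x, w)  # same shape/dtype as x, normalized by the CUDA kernel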