|
23 | 23 |
|
24 | 24 | #include <cstdint>
|
25 | 25 | #include <iostream>
|
26 |
| -#include <sstream> |
27 |
| -#include <stdexcept> |
28 | 26 | #include <vector>
|
29 | 27 |
|
| 28 | +#include "exception.h" |
| 29 | + |
// STR_HELPER(x): applies the preprocessor `#` operator to x, producing a string literal of the
// argument exactly as spelled. If x is itself a macro name, it is NOT expanded first.
30 | 30 | #define STR_HELPER(x) #x
|
// STR(x): stringifies x AFTER macro expansion. The extra level of indirection lets the
// preprocessor fully expand x before STR_HELPER applies `#` to the result.
31 | 31 | #define STR(x) STR_HELPER(x)
|
32 | 32 |
|
|
57 | 57 |
|
58 | 58 | #define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \
|
59 | 59 | if (allow_fp16_qk_reduction) { \
|
60 |
| - throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \ |
| 60 | + FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ |
61 | 61 | } else { \
|
62 | 62 | constexpr bool ALLOW_FP16_QK_REDUCTION = false; \
|
63 | 63 | __VA_ARGS__ \
|
|
73 | 73 | } else { \
|
74 | 74 | std::ostringstream err_msg; \
|
75 | 75 | err_msg << "Unsupported num_frags_q: " << num_frags_q; \
|
76 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 76 | + FLASHINFER_ERROR(err_msg.str()); \ |
77 | 77 | }
|
78 | 78 |
|
79 | 79 | #define DISPATCH_NUM_FRAGS_KV(max_frags_kv, NUM_FRAGS_KV, ...) \
|
|
92 | 92 | } else { \
|
93 | 93 | std::ostringstream err_msg; \
|
94 | 94 | err_msg << "Unsupported max_frags_kv: " << max_frags_kv; \
|
95 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 95 | + FLASHINFER_ERROR(err_msg.str()); \ |
96 | 96 | }
|
97 | 97 |
|
98 | 98 | #define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
|
|
115 | 115 | default: { \
|
116 | 116 | std::ostringstream err_msg; \
|
117 | 117 | err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
|
118 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 118 | + FLASHINFER_ERROR(err_msg.str()); \ |
119 | 119 | } \
|
120 | 120 | }
|
121 | 121 |
|
|
138 | 138 | } else { \
|
139 | 139 | std::ostringstream err_msg; \
|
140 | 140 | err_msg << "Unsupported group_size: " << group_size; \
|
141 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 141 | + FLASHINFER_ERROR(err_msg.str()); \ |
142 | 142 | }
|
143 | 143 |
|
144 | 144 | #define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
|
|
161 | 161 | default: { \
|
162 | 162 | std::ostringstream err_msg; \
|
163 | 163 | err_msg << "Unsupported mask_mode: " << int(mask_mode); \
|
164 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 164 | + FLASHINFER_ERROR(err_msg.str()); \ |
165 | 165 | } \
|
166 | 166 | }
|
167 | 167 |
|
|
190 | 190 | default: { \
|
191 | 191 | std::ostringstream err_msg; \
|
192 | 192 | err_msg << "Unsupported head_dim: " << head_dim; \
|
193 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 193 | + FLASHINFER_ERROR(err_msg.str()); \ |
194 | 194 | } \
|
195 | 195 | }
|
196 | 196 |
|
|
214 | 214 | default: { \
|
215 | 215 | std::ostringstream err_msg; \
|
216 | 216 | err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
|
217 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 217 | + FLASHINFER_ERROR(err_msg.str()); \ |
218 | 218 | } \
|
219 | 219 | }
|
220 | 220 |
|
|
248 | 248 | default: { \
|
249 | 249 | std::ostringstream err_msg; \
|
250 | 250 | err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
|
251 |
| - throw std::invalid_argument(err_msg.str()); \ |
| 251 | + FLASHINFER_ERROR(err_msg.str()); \ |
252 | 252 | } \
|
253 | 253 | }
|
254 | 254 |
|
|
0 commit comments