@@ -62,14 +62,35 @@ decltype(auto) transform_kernel_arg(const SourceContext& context, T&& arg) {
   }
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// Verify Kernel Argument
+//
+// Verify certain arguments before and after kernel invocation
+////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+decltype(auto) check_kernel_arg(const SourceContext& context, T&& arg) {
+  if constexpr (is_tensor_accessor_builder_v<std::decay_t<T>>) {
+    // If the arg is a TensorAccessorBuilder, run verifications on the tensor it
+    // is ref-wrapping, e.g. NaN value checks.
+    return arg.checkValues(context.description());
+  } else {
+    // Otherwise, perfect-forward the argument as is
+    return std::forward<T>(arg);
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // GPU Kernel Launcher
 //
 // This class encapsulates the common ceremonial pre- and post-execution
 // routines when launching GPU kernels.
 ////////////////////////////////////////////////////////////////////////////////
 
-template <bool EnableDSA = false, bool EnableBarrierIsolation = false>
+template <
+    bool EnableDSA = false,
+    bool EnableBarrierIsolation = false,
+    bool EnableNaNChecks = false>
 struct KernelLauncher {
   const SourceContext context;
 
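(Note on the pattern above: `check_kernel_arg` dispatches at compile time on whether an argument is a `TensorAccessorBuilder`. The FBGEMM trait and builder definitions are not part of this diff; below is a minimal, self-contained sketch of the same `if constexpr` + detection-trait idiom, with stand-in names.)

```cpp
#include <iostream>
#include <type_traits>
#include <utility>

// Stand-in for a builder type that wraps a tensor (hypothetical, simplified).
struct FakeBuilder {
  int checkValues(const char* description) const {
    std::cout << "verifying tensor for: " << description << '\n';
    return 42;  // stand-in for the accessor the real builder would produce
  }
};

// Detection trait, analogous in spirit to is_tensor_accessor_builder_v.
template <typename T>
inline constexpr bool is_fake_builder_v = std::is_same_v<T, FakeBuilder>;

// Same shape as check_kernel_arg: verify builders, perfect-forward the rest.
template <typename T>
decltype(auto) check_arg(const char* description, T&& arg) {
  if constexpr (is_fake_builder_v<std::decay_t<T>>) {
    return arg.checkValues(description);
  } else {
    return std::forward<T>(arg);
  }
}

int main() {
  FakeBuilder b;
  auto accessor = check_arg("demo_kernel", b);  // takes the verification path
  auto n = check_arg("demo_kernel", 7);         // forwarded unchanged
  std::cout << accessor << ' ' << n << '\n';
}
```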
@@ -223,6 +244,21 @@ struct KernelLauncher {
     // device associated with the compute stream
     checkSharedMemoryPerBlockNotExceeded(properties, shared_mem_per_block);
 
+    // If NaN checks are enabled, run verifications on all kernel arguments that
+    // are tensors
+    if constexpr (EnableNaNChecks) {
+      const auto summary = std::string(context.summary) + " (pre-execution)";
+      (check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
+       ...);
+    }
+
+    // If barrier isolation is enabled, synchronize the stream first before
+    // launching the kernel. This has roughly the same effect as setting
+    // `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
+    if constexpr (EnableBarrierIsolation) {
+      cudaDeviceSynchronize();
+    }
+
     if constexpr (EnableDSA) {
       // This launch code here is essentially the same as the contents of
       // TORCH_USE_CUDA_DSA macro, but with the addition of kernel argument
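(The `(check_kernel_arg(...), ...)` line above is a C++17 fold over the comma operator, which applies the check to every element of the argument pack while discarding the results. A minimal sketch of the idiom, with illustrative names:)

```cpp
#include <iostream>
#include <utility>

template <typename T>
void check_one(T&& arg) {
  std::cout << "checked: " << arg << '\n';
}

// Fold over the comma operator: expands to
//   (check_one(a0), (check_one(a1), check_one(a2)))
// and evaluates left to right, one call per pack element.
template <typename... Args>
void check_all(Args&&... args) {
  (check_one(std::forward<Args>(args)), ...);
}

int main() {
  check_all(1, 2.5, "three");  // prints three "checked:" lines
}
```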
@@ -240,13 +276,6 @@ struct KernelLauncher {
         c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
 #endif
 
-    // If barrier isolation is enabled, synchronize the stream first before
-    // launching the kernel. This has roughly the same effect as setting
-    // `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
-    if constexpr (EnableBarrierIsolation) {
-      cudaDeviceSynchronize();
-    }
-
     // Launch the kernel
     kernel<<<grid, block, shared_mem_per_block, stream>>>(
         // Transform arguments to the kernel before forwarding them.
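(The barrier-isolation block relocated by this hunk serializes launches so that an asynchronous failure can be pinned to a single kernel. A standalone CUDA sketch of the idea; the kernel and buffer here are hypothetical:)

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel used only to illustrate the synchronization points.
__global__ void my_kernel(float* out) {
  out[threadIdx.x] = 1.0f;
}

int main() {
  float* buf = nullptr;
  cudaMalloc(&buf, 32 * sizeof(float));

  // Draining the device before the launch ensures any error surfaced
  // afterwards belongs to this kernel, not to earlier asynchronous work.
  cudaDeviceSynchronize();

  my_kernel<<<1, 32>>>(buf);

  // Synchronizing again right after the launch reports failures at the
  // call site, approximating CUDA_LAUNCH_BLOCKING=1 for this one launch.
  const cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    std::printf("kernel failed: %s\n", cudaGetErrorString(err));
  }

  cudaFree(buf);
  return 0;
}
```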
@@ -274,6 +303,14 @@ struct KernelLauncher {
 
     // Check for CUDA errors
     C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    // If NaN checks are enabled, run post-kernel verifications on all kernel
+    // arguments that are tensors
+    if constexpr (EnableNaNChecks) {
+      const auto summary = std::string(context.summary) + " (post-execution)";
+      (check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
+       ...);
+    }
   }
 };
 
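(The diff does not show `TensorAccessorBuilder::checkValues` itself. Under the assumption that it scans the wrapped tensor for NaNs, a plausible check using the public ATen API might look like the following; the helper name is hypothetical.)

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <string>

// Hypothetical helper in the spirit of checkValues(): raise if the tensor
// contains NaNs, tagging the error with the launch-site description
// (e.g. "my_kernel (pre-execution)" or "my_kernel (post-execution)").
inline void check_tensor_for_nans(
    const at::Tensor& tensor,
    const std::string& description) {
  if (tensor.is_floating_point() &&
      at::isnan(tensor).any().item<bool>()) {
    TORCH_CHECK(false, "NaN values found in tensor: ", description);
  }
}
```

Reading a scalar back with `.item<bool>()` synchronizes with the GPU, so a post-execution check of this shape also observes the kernel's completed writes.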
@@ -309,30 +346,38 @@ struct KernelLauncher {
 #define _FKL_TFILE_ ""
 #endif
 
-#ifdef FBGEMM_GPU_KERNEL_DEBUG
-#define _FKL_KDEBUG_ true
+#ifdef FBGEMM_GPU_ISOLATE_KERNEL_LAUNCH
+#define _FKL_BLOCKING_ true
+#else
+#define _FKL_BLOCKING_ false
+#endif
+
+#ifdef FBGEMM_GPU_TENSORCHECK
+#define _FKL_TENSORCHECK_ true
 #else
-#define _FKL_KDEBUG_ false
+#define _FKL_TENSORCHECK_ false
 #endif
 
-#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...)     \
-  ([&] {                                                                 \
-    using source_location = fbgemm_gpu::utils::source_location;          \
-    constexpr auto location = source_location::current();                \
-    decltype(KERNEL)& kernel = KERNEL;                                   \
-                                                                         \
-    return fbgemm_gpu::utils::KernelLauncher<false, _FKL_KDEBUG_>(       \
-        location, #KERNEL, _FKL_TFILE_)                                  \
-        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__);  \
+#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...)     \
+  ([&] {                                                                 \
+    using source_location = fbgemm_gpu::utils::source_location;          \
+    constexpr auto location = source_location::current();                \
+    decltype(KERNEL)& kernel = KERNEL;                                   \
+                                                                         \
+    return fbgemm_gpu::utils::                                           \
+        KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_>(        \
+            location, #KERNEL, _FKL_TFILE_)                              \
+        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__);  \
   }())
 
-#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
-  ([&] {                                                                 \
-    using source_location = fbgemm_gpu::utils::source_location;          \
-    constexpr auto location = source_location::current();                \
-    decltype(KERNEL)& kernel = KERNEL;                                   \
-                                                                         \
-    return fbgemm_gpu::utils::KernelLauncher<true, _FKL_KDEBUG_>(        \
-        location, #KERNEL, _FKL_TFILE_)                                  \
-        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__);  \
+#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
+  ([&] {                                                                 \
+    using source_location = fbgemm_gpu::utils::source_location;          \
+    constexpr auto location = source_location::current();                \
+    decltype(KERNEL)& kernel = KERNEL;                                   \
+                                                                         \
+    return fbgemm_gpu::utils::                                           \
+        KernelLauncher<true, _FKL_BLOCKING_, _FKL_TENSORCHECK_>(         \
+            location, #KERNEL, _FKL_TFILE_)                              \
+        .launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__);  \
   }())
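(For context, a sketch of how the rewritten macro is invoked at a call site; the kernel name, launch dimensions, and arguments below are illustrative stand-ins, not part of this diff.)

```cpp
// Hypothetical call site for the macro defined above.
FBGEMM_LAUNCH_KERNEL(
    my_kernel,                         // kernel symbol; also stringified via #KERNEL
    dim3(num_blocks),                  // grid dimensions
    dim3(threads_per_block),           // block dimensions
    0,                                 // dynamic shared memory, in bytes
    at::cuda::getCurrentCUDAStream(),  // compute stream
    out_ptr,                           // trailing args are forwarded to the kernel;
    num_elements);                     // builder args are additionally verified when
                                       // FBGEMM_GPU_TENSORCHECK is defined
```

Wrapping the macro body in an immediately-invoked lambda lets it behave as a single expression at the call site while still declaring locals such as `location` and `kernel`.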