Skip to content

Commit ca9b4ca

Browse files
q10facebook-github-bot
authored andcommitted
Enable NaN checks on tensor arguments to kernel launches (#4029)
Summary: X-link: facebookresearch/FBGEMM#1113 - Enable NaN checks on tensor arguments to kernel launches Differential Revision: D73698678
1 parent eeee38e commit ca9b4ca

File tree

4 files changed

+225
-32
lines changed

4 files changed

+225
-32
lines changed

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

+74-29
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,35 @@ decltype(auto) transform_kernel_arg(const SourceContext& context, T&& arg) {
6262
}
6363
}
6464

65+
////////////////////////////////////////////////////////////////////////////////
66+
// Verify Kernel Argument
67+
//
68+
// Verify certain arguments before and after kernel invocation
69+
////////////////////////////////////////////////////////////////////////////////
70+
71+
template <typename T>
72+
decltype(auto) check_kernel_arg(const SourceContext& context, T&& arg) {
73+
if constexpr (is_tensor_accessor_builder_v<std::decay_t<T>>) {
74+
// If the arg is a TensorAccessorBuilder, run verifications on the tensor it
75+
// is ref-wrapping, e.g. NaN value checks.
76+
return arg.checkValues(context.description());
77+
} else {
78+
// Otherwise, perfect-forward the argument as is
79+
return std::forward<T>(arg);
80+
}
81+
}
82+
6583
////////////////////////////////////////////////////////////////////////////////
6684
// GPU Kernel Launcher
6785
//
6886
// This class encapsulates the common ceremonial pre- and post-execution
6987
// routines when launching GPU kernels.
7088
////////////////////////////////////////////////////////////////////////////////
7189

72-
template <bool EnableDSA = false, bool EnableBarrierIsolation = false>
90+
template <
91+
bool EnableDSA = false,
92+
bool EnableBarrierIsolation = false,
93+
bool EnableNaNChecks = false>
7394
struct KernelLauncher {
7495
const SourceContext context;
7596

@@ -223,6 +244,21 @@ struct KernelLauncher {
223244
// device associated with the compute stream
224245
checkSharedMemoryPerBlockNotExceeded(properties, shared_mem_per_block);
225246

247+
// If NaN checks are enabled, run verifications on all kernel arguments that
248+
// are tensors
249+
if constexpr (EnableNaNChecks) {
250+
const auto summary = std::string(context.summary) + " (pre-execution)";
251+
(check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
252+
...);
253+
}
254+
255+
// If barrier isolation is enabled, synchronize the stream first before
256+
// launching the kernel. This has roughly the same effect as setting
257+
// `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
258+
if constexpr (EnableBarrierIsolation) {
259+
cudaDeviceSynchronize();
260+
}
261+
226262
if constexpr (EnableDSA) {
227263
// This launch code here is essentially the same as the contents of
228264
// TORCH_USE_CUDA_DSA macro, but with the addition of kernel argument
@@ -240,13 +276,6 @@ struct KernelLauncher {
240276
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
241277
#endif
242278

243-
// If barrier isolation is enabled, synchronize the stream first before
244-
// launching the kernel. This has roughly the same effect as setting
245-
// `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
246-
if constexpr (EnableBarrierIsolation) {
247-
cudaDeviceSynchronize();
248-
}
249-
250279
// Launch the kernel
251280
kernel<<<grid, block, shared_mem_per_block, stream>>>(
252281
// Transform arguments to the kernel before forwarding them.
@@ -274,6 +303,14 @@ struct KernelLauncher {
274303

275304
// Check for CUDA errors
276305
C10_CUDA_KERNEL_LAUNCH_CHECK();
306+
307+
// If NaN checks are enabled, run post-kernel verifications on all kernel
308+
// arguments that are tensors
309+
if constexpr (EnableNaNChecks) {
310+
const auto summary = std::string(context.summary) + " (post-execution)";
311+
(check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
312+
...);
313+
}
277314
}
278315
};
279316

@@ -309,30 +346,38 @@ struct KernelLauncher {
309346
#define _FKL_TFILE_ ""
310347
#endif
311348

312-
#ifdef FBGEMM_GPU_KERNEL_DEBUG
313-
#define _FKL_KDEBUG_ true
349+
#ifdef FBGEMM_GPU_ISOLATE_KERNEL_LAUNCH
350+
#define _FKL_BLOCKING_ true
351+
#else
352+
#define _FKL_BLOCKING_ false
353+
#endif
354+
355+
#ifdef FBGEMM_GPU_TENSORCHECK
356+
#define _FKL_TENSORCHECK_ true
314357
#else
315-
#define _FKL_KDEBUG_ false
358+
#define _FKL_TENSORCHECK_ false
316359
#endif
317360

318-
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
319-
([&] { \
320-
using source_location = fbgemm_gpu::utils::source_location; \
321-
constexpr auto location = source_location::current(); \
322-
decltype(KERNEL)& kernel = KERNEL; \
323-
\
324-
return fbgemm_gpu::utils::KernelLauncher<false, _FKL_KDEBUG_>( \
325-
location, #KERNEL, _FKL_TFILE_) \
326-
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
361+
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
362+
([&] { \
363+
using source_location = fbgemm_gpu::utils::source_location; \
364+
constexpr auto location = source_location::current(); \
365+
decltype(KERNEL)& kernel = KERNEL; \
366+
\
367+
return fbgemm_gpu::utils:: \
368+
KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_>( \
369+
location, #KERNEL, _FKL_TFILE_) \
370+
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
327371
}())
328372

329-
#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
330-
([&] { \
331-
using source_location = fbgemm_gpu::utils::source_location; \
332-
constexpr auto location = source_location::current(); \
333-
decltype(KERNEL)& kernel = KERNEL; \
334-
\
335-
return fbgemm_gpu::utils::KernelLauncher<true, _FKL_KDEBUG_>( \
336-
location, #KERNEL, _FKL_TFILE_) \
337-
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
373+
#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
374+
([&] { \
375+
using source_location = fbgemm_gpu::utils::source_location; \
376+
constexpr auto location = source_location::current(); \
377+
decltype(KERNEL)& kernel = KERNEL; \
378+
\
379+
return fbgemm_gpu::utils:: \
380+
KernelLauncher<true, _FKL_BLOCKING_, _FKL_TENSORCHECK_>( \
381+
location, #KERNEL, _FKL_TFILE_) \
382+
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
338383
}())

fbgemm_gpu/include/fbgemm_gpu/utils/source_context.h

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ struct SourceContext {
8080

8181
return *desc_;
8282
}
83+
84+
inline SourceContext withSummary(
85+
const std::string_view& sum_) const noexcept {
86+
return SourceContext(location, sum_, secondaryLocation);
87+
}
8388
};
8489

8590
} // namespace fbgemm_gpu::utils

fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor_builder.h

+20
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,26 @@ struct TensorAccessorBuilder {
220220
return build_ta(context);
221221
}
222222
}
223+
224+
//////////////////////////////////////////////////////////////////////////////
225+
// Check Tensor values for NaN
226+
//////////////////////////////////////////////////////////////////////////////
227+
228+
C10_ALWAYS_INLINE void checkValues(const std::string_view& context) const {
229+
TORCH_CHECK(
230+
!at::isnan(tensor).any().item<bool>(),
231+
context,
232+
": Tensor '",
233+
name,
234+
"' contains NaN values!");
235+
236+
TORCH_CHECK(
237+
!at::isinf(tensor).any().item<bool>(),
238+
context,
239+
": Tensor '",
240+
name,
241+
"' contains (+/-) Inf values!");
242+
}
223243
};
224244

225245
} // namespace fbgemm_gpu::utils

fbgemm_gpu/test/utils/kernel_launcher_test.cu

+126-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
// FBGEMM codebase to denote the template source file in auto-generated code.
1111
#define __TEMPLATE_SOURCE_FILE__ "FOO/BAR/BAZ-123.cpp"
1212

13+
// Enable tensor value checking before and after executing kernels
14+
#define FBGEMM_GPU_TENSORCHECK
15+
1316
#include <ATen/ATen.h>
1417
#include <c10/cuda/CUDADeviceAssertion.h>
1518
#include <cuda.h>
@@ -71,6 +74,44 @@ __global__ void tensor_sum_kernel(
7174
}
7275
}
7376

77+
__device__ unsigned int xor128_rand_int(uint32_t seed) {
78+
auto x = seed ^ (blockIdx.x * blockDim.x + threadIdx.x);
79+
x ^= x << 13;
80+
x ^= x >> 17;
81+
x ^= x << 5;
82+
return x;
83+
}
84+
85+
template <typename T>
86+
__global__ void tensor_sum_kernel_bad_output(
87+
pta::PackedTensorAccessor64<T, 1, at::RestrictPtrTraits> C,
88+
const pta::PackedTensorAccessor64<T, 1, at::RestrictPtrTraits> A,
89+
const pta::PackedTensorAccessor64<T, 1, at::RestrictPtrTraits> B,
90+
TORCH_DSA_KERNEL_ARGS) {
91+
const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
92+
auto seed = xor128_rand_int(42);
93+
94+
if (idx < C.size(0)) {
95+
if (seed = xor128_rand_int(seed); seed % 100 != 0) {
96+
// 99% chance of normal value
97+
C[idx] = A[idx] + B[idx];
98+
99+
} else {
100+
seed = xor128_rand_int(seed);
101+
102+
if (seed % 3 == 0) {
103+
C[idx] = std::numeric_limits<T>::quiet_NaN();
104+
105+
} else if (seed % 3 == 1) {
106+
C[idx] = std::numeric_limits<T>::infinity();
107+
108+
} else {
109+
C[idx] = std::numeric_limits<T>::infinity();
110+
}
111+
}
112+
}
113+
}
114+
74115
__global__ void always_fail_assertion_kernel(
75116
const int a,
76117
TORCH_DSA_KERNEL_ARGS) {
@@ -197,7 +238,7 @@ TEST(KernelLauncherTest, array_kernel_launch_dsa) {
197238
});
198239
}
199240

200-
TEST(KernelLauncherTest, tensor_array_kernel_launch) {
241+
TEST(KernelLauncherTest, tensor_kernel_launch) {
201242
const auto size = 1024;
202243
// Not using structured bindings bc it fails on ROCm with:
203244
// `capturing a structured binding is not yet supported in OpenMP`
@@ -276,8 +317,8 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
276317
{
277318
FBGEMM_LAUNCH_DSA_KERNEL(
278319
tensor_sum_kernel<float>,
279-
// Both grid and block dims conform, but the total number of threads
280-
// exceeds the max
320+
// Both grid and block dims conform, but the total number of
321+
// threads exceeds the max
281322
{U32(grid_max[0]), U32(grid_max[1]), U32(grid_max[2])},
282323
{U32(block_max[0]), U32(block_max[1]), U32(block_max[2])},
283324
0,
@@ -306,6 +347,88 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
306347
std::exception);
307348
}
308349

350+
TEST(KernelLauncherTest, tensor_value_checks) {
351+
const auto size = 1024;
352+
// Not using structured bindings bc it fails on ROCm with:
353+
// `capturing a structured binding is not yet supported in OpenMP`
354+
at::Tensor A, B, C;
355+
std::tie(A, B, C) = sample_tensors(size);
356+
357+
{
358+
// Test for bad INPUT tensors
359+
360+
const float values[] = {
361+
std::numeric_limits<float>::quiet_NaN(),
362+
std::numeric_limits<float>::infinity(),
363+
-std::numeric_limits<float>::infinity(),
364+
};
365+
366+
for (const auto value : values) {
367+
// Set a bad value
368+
auto i = rand() % size;
369+
A[i] = value;
370+
371+
EXPECT_THROW(
372+
{
373+
FBGEMM_LAUNCH_DSA_KERNEL(
374+
tensor_sum_kernel<float>,
375+
8,
376+
1024,
377+
0,
378+
at::cuda::getCurrentCUDAStream(),
379+
PTA_B(C, float, 1, 64),
380+
PTA_B(A, float, 1, 64),
381+
PTA_B(B, float, 1, 64));
382+
},
383+
std::exception);
384+
385+
// Unset the bad value
386+
A[i] = 1;
387+
}
388+
389+
for (const auto value : values) {
390+
// Set a bad value
391+
auto i = rand() % size;
392+
B[i] = value;
393+
394+
EXPECT_THROW(
395+
{
396+
FBGEMM_LAUNCH_DSA_KERNEL(
397+
tensor_sum_kernel<float>,
398+
8,
399+
1024,
400+
0,
401+
at::cuda::getCurrentCUDAStream(),
402+
PTA_B(C, float, 1, 64),
403+
PTA_B(A, float, 1, 64),
404+
PTA_B(B, float, 1, 64));
405+
},
406+
std::exception);
407+
408+
// Unset the bad value
409+
B[i] = 1;
410+
}
411+
}
412+
413+
{
414+
// Test for bad OUTPUT tensors
415+
416+
EXPECT_THROW(
417+
{
418+
FBGEMM_LAUNCH_DSA_KERNEL(
419+
tensor_sum_kernel_bad_output<float>,
420+
8,
421+
1024,
422+
0,
423+
at::cuda::getCurrentCUDAStream(),
424+
PTA_B(C, float, 1, 64),
425+
PTA_B(A, float, 1, 64),
426+
PTA_B(B, float, 1, 64));
427+
},
428+
std::exception);
429+
}
430+
}
431+
309432
// NOTE: This test currently fails in fbcode CI for HIP with the following
310433
// error (but runs without issues on both NVIDIA and AMD machines):
311434
//

0 commit comments

Comments
 (0)