Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FBcode->GH] Fix missing kernel guards (#4620) #4622

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ void deformable_im2col(
int deformable_group,
bool use_mask,
at::Tensor data_col) {
int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
at::cuda::CUDAGuard device_guard(input.get_device());

const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;

const unsigned int threads = GET_THREADS();
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
Expand Down Expand Up @@ -408,12 +410,14 @@ void compute_grad_input(
int n_offset_grps,
bool use_mask,
at::Tensor grad_im) {
int out_h =
at::cuda::CUDAGuard device_guard(columns.get_device());

const int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
const int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;

int64_t num_kernels =
const int64_t num_kernels =
(int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs;

const unsigned int threads = GET_THREADS();
Expand Down Expand Up @@ -650,11 +654,13 @@ void compute_grad_offset_and_mask(
bool use_mask,
at::Tensor grad_offset,
at::Tensor grad_mask) {
int out_h =
at::cuda::CUDAGuard device_guard(columns.get_device());

const int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
const int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
n_offset_grps * parallel_imgs;

const unsigned int threads = GET_THREADS();
Expand Down