|
| 1 | +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | + |
| 3 | +#include <ATen/cuda/CUDAContext.h> |
| 4 | +#include <c10/cuda/CUDAGuard.h> |
| 5 | +#include <torch/extension.h> |
| 6 | +#include <cmath> |
| 7 | +#include <vector> |
| 8 | + |
| 9 | +template <typename scalar_t> |
| 10 | +__global__ void SigmoidAlphaBlendForwardKernel( |
| 11 | + // clang-format off |
| 12 | + const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> distances, // (N, H, W, K) |
| 13 | + const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> pix_to_face, // (N, H, W, K) |
| 14 | + torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> alphas, // (N, H, W) |
| 15 | + // clang-format on |
| 16 | + const scalar_t sigma, |
| 17 | + const int N, |
| 18 | + const int H, |
| 19 | + const int W, |
| 20 | + const int K) { |
| 21 | + // Parallelize over each pixel in images of |
| 22 | + // size H * W, for each image in the batch of size N. |
| 23 | + const int num_threads = gridDim.x * blockDim.x; |
| 24 | + const int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 25 | + |
| 26 | + // TODO: revisit performance of this kernel with shared memory usage |
| 27 | + |
| 28 | + for (int t_i = tid; t_i < N * H * W; t_i += num_threads) { |
| 29 | + // Convert linear index to 3D index |
| 30 | + const int n = t_i / (H * W); // batch index. |
| 31 | + const int pix_idx = t_i % (H * W); |
| 32 | + |
| 33 | + // TODO: fix index calculation for non square images. |
| 34 | + const int yi = pix_idx / W; |
| 35 | + const int xi = pix_idx % W; |
| 36 | + scalar_t alpha = 1.0; |
| 37 | + |
| 38 | + // Loop over all the faces for this pixel. |
| 39 | + for (int k = 0; k < K; k++) { |
| 40 | + // Index into (N, H, W, K) tensors |
| 41 | + const int f = pix_to_face[n][yi][xi][k]; |
| 42 | + if (f < 0) { |
| 43 | + // Sentinel value is -1 indicating no face overlaps the pixel. |
| 44 | + continue; |
| 45 | + } |
| 46 | + // The distance is negative if a pixel is inside a face and positive |
| 47 | + // outside the face. Therefore use -1.0 * the distance to get the |
| 48 | + // correct sign. |
| 49 | + scalar_t dist = -1.0 * distances[n][yi][xi][k]; |
| 50 | + |
| 51 | + // Calculate the sigmoid probability. |
| 52 | + scalar_t prob = 1. / (1. + exp(-dist / sigma)); |
| 53 | + |
| 54 | + // The cumulative product ensures that alpha will be 0.0 if at least 1 |
| 55 | + // face fully covers the pixel as for that face, prob will be 1.0. |
| 56 | + // This results in a multiplication by 0.0 because of the (1.0 - prob) |
| 57 | + // term. Therefore the final result of (1.0 - alpha) will be 1.0. |
| 58 | + alpha *= (1.0 - prob); |
| 59 | + } |
| 60 | + alphas[n][yi][xi] = 1.0 - alpha; |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +torch::Tensor SigmoidAlphaBlendForwardCuda( |
| 65 | + const at::Tensor& distances, // (N, H, W, K) |
| 66 | + const at::Tensor& pix_to_face, // (N, H, W, K) |
| 67 | + const float sigma) { |
| 68 | + const int N = distances.size(0); |
| 69 | + const int H = distances.size(1); |
| 70 | + const int W = distances.size(2); |
| 71 | + const int K = distances.size(3); |
| 72 | + |
| 73 | + at::Tensor alphas = at::zeros({N, H, W}, distances.options()); |
| 74 | + const size_t blocks = 1024; |
| 75 | + const size_t threads = 128; |
| 76 | + |
| 77 | + // Check inputs are on the same device |
| 78 | + at::TensorArg distances_t{distances, "distances", 1}, |
| 79 | + pix_to_face_t{pix_to_face, "pix_to_face", 2}; |
| 80 | + at::CheckedFrom c = "SigmoidAlphaBlendForwardCuda"; |
| 81 | + at::checkAllSameGPU(c, {distances_t, pix_to_face_t}); |
| 82 | + |
| 83 | + // Set the device for the kernel launch based on the device of distances |
| 84 | + at::cuda::CUDAGuard device_guard(distances.device()); |
| 85 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 86 | + |
| 87 | + if (distances.numel() == 0) { |
| 88 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 89 | + return alphas; |
| 90 | + } |
| 91 | + |
| 92 | + AT_DISPATCH_FLOATING_TYPES( |
| 93 | + distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] { |
| 94 | + // clang-format off |
| 95 | + SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>( |
| 96 | + distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(), |
| 97 | + pix_to_face.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>(), |
| 98 | + alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(), |
| 99 | + sigma, |
| 100 | + N, |
| 101 | + H, |
| 102 | + W, |
| 103 | + K); |
| 104 | + // clang-format on |
| 105 | + })); |
| 106 | + |
| 107 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 108 | + return alphas; |
| 109 | +} |
| 110 | + |
| 111 | +template <typename scalar_t> |
| 112 | +__global__ void SigmoidAlphaBlendBackwardKernel( |
| 113 | + // clang-format off |
| 114 | + const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> grad_alphas, // (N, H, W) |
| 115 | + const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> alphas, // (N, H, W) |
| 116 | + const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> distances, // (N, H, W, K) |
| 117 | + const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> pix_to_face, // (N, H, W, K) |
| 118 | + torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> grad_distances, // (N, H, W) |
| 119 | + // clang-format on |
| 120 | + const scalar_t sigma, |
| 121 | + const int N, |
| 122 | + const int H, |
| 123 | + const int W, |
| 124 | + const int K) { |
| 125 | + // Parallelize over each of the top K faces for each pixel in images of |
| 126 | + // size H * W * K, for each image in the batch of size N. |
| 127 | + |
| 128 | + // Get block and thread index. |
| 129 | + const int n = blockIdx.x; |
| 130 | + const int num_pixels = H * W * K; |
| 131 | + const int num_threads = gridDim.y * blockDim.x; |
| 132 | + const int tid = blockIdx.y * blockDim.x + threadIdx.x; |
| 133 | + |
| 134 | + for (int t_i = tid; t_i < num_pixels; t_i += num_threads) { |
| 135 | + // Convert linear index to 3D index. |
| 136 | + int yi = t_i / (W * K); |
| 137 | + int xi = (t_i % (W * K)) / K; |
| 138 | + int k = (t_i % (W * K)) % K; |
| 139 | + |
| 140 | + const scalar_t alpha = 1.0 - alphas[n][yi][xi]; |
| 141 | + const scalar_t grad_alpha = grad_alphas[n][yi][xi]; |
| 142 | + const int f = pix_to_face[n][yi][xi][k]; |
| 143 | + |
| 144 | + // Sentinel value is -1 indicating no face overlaps the pixel. |
| 145 | + if (f >= 0) { |
| 146 | + // The distance is negative if a pixel is inside a face and positive |
| 147 | + // outside the face. Therefore use -1.0 * the distance to get the |
| 148 | + // correct sign. |
| 149 | + scalar_t dist = -1.0 * distances[n][yi][xi][k]; |
| 150 | + |
| 151 | + // Calculate the sigmoid probability. |
| 152 | + scalar_t prob = 1. / (1. + exp(-dist / sigma)); |
| 153 | + |
| 154 | + grad_distances[n][yi][xi][k] = grad_alpha * (-1.0 / sigma) * prob * alpha; |
| 155 | + } |
| 156 | + } |
| 157 | +} |
| 158 | + |
| 159 | +torch::Tensor SigmoidAlphaBlendBackwardCuda( |
| 160 | + const at::Tensor& grad_alphas, // (N, H, W) |
| 161 | + const at::Tensor& alphas, // (N, H, W) |
| 162 | + const at::Tensor& distances, // (N, H, W, K) |
| 163 | + const at::Tensor& pix_to_face, // (N, H, W, K) |
| 164 | + float sigma) { |
| 165 | + const int N = distances.size(0); |
| 166 | + const int H = distances.size(1); |
| 167 | + const int W = distances.size(2); |
| 168 | + const int K = distances.size(3); |
| 169 | + |
| 170 | + at::Tensor grad_distances = at::zeros({N, H, W, K}, distances.options()); |
| 171 | + |
| 172 | + const dim3 threads(512); |
| 173 | + const dim3 blocks(N, 1024 / N + 1); |
| 174 | + |
| 175 | + at::TensorArg grad_alphas_t{grad_alphas, "grad_alphas", 1}, |
| 176 | + alphas_t{alphas, "alphas", 2}, distances_t{distances, "distances", 3}, |
| 177 | + pix_to_face_t{pix_to_face, "pix_to_face", 4}; |
| 178 | + at::CheckedFrom c = "SigmoidAlphaBlendBackwardCuda"; |
| 179 | + at::checkAllSameGPU(c, {grad_alphas_t, alphas_t, distances_t, pix_to_face_t}); |
| 180 | + |
| 181 | + // Set the device for the kernel launch based on the device of distances |
| 182 | + at::cuda::CUDAGuard device_guard(alphas.device()); |
| 183 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 184 | + |
| 185 | + if (alphas.numel() == 0) { |
| 186 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 187 | + return grad_alphas; |
| 188 | + } |
| 189 | + |
| 190 | + AT_DISPATCH_FLOATING_TYPES( |
| 191 | + distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] { |
| 192 | + SigmoidAlphaBlendBackwardKernel<scalar_t> |
| 193 | + <<<blocks, threads, 0, stream>>>( |
| 194 | + // clang-format off |
| 195 | + grad_alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(), |
| 196 | + alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(), |
| 197 | + distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(), |
| 198 | + pix_to_face.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>(), |
| 199 | + grad_distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(), |
| 200 | + // clang-format on |
| 201 | + sigma, |
| 202 | + N, |
| 203 | + H, |
| 204 | + W, |
| 205 | + K); |
| 206 | + })); |
| 207 | + |
| 208 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 209 | + return grad_distances; |
| 210 | +} |
0 commit comments