
Commit bce396d

nikhilaravi authored and facebook-github-bot committed
C++/CUDA implementation of sigmoid alpha blend
Summary: C++/CUDA implementation of forward and backward passes for the sigmoid alpha blending function. This is slightly faster than the vectorized implementation in Python, but more importantly uses less memory due to fewer tensors being created.

Reviewed By: gkioxari

Differential Revision: D19980671

fbshipit-source-id: 0779055d2c68b1f20fb0870e60046077ef4613ff
1 parent dc08c30 commit bce396d
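For reference, the forward pass implemented below computes, for each pixel, a blend of the top K rasterized faces based on their signed 2D distances d_k to the pixel (negative inside a face). Reading the formula directly off the kernel code:

  \alpha = 1 - \prod_{k=1}^{K} \left( 1 - \mathrm{sigmoid}\!\left( -d_k / \sigma \right) \right)

This is the sigmoid alpha blending of Soft Rasterizer [Liu et al.], as documented in the header file at the end of this diff.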

File tree

8 files changed: +513 −55 lines changed
@@ -0,0 +1,210 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <cmath>
#include <vector>

template <typename scalar_t>
__global__ void SigmoidAlphaBlendForwardKernel(
    // clang-format off
    const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> distances, // (N, H, W, K)
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> pix_to_face, // (N, H, W, K)
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> alphas, // (N, H, W)
    // clang-format on
    const scalar_t sigma,
    const int N,
    const int H,
    const int W,
    const int K) {
  // Parallelize over each pixel in images of
  // size H * W, for each image in the batch of size N.
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // TODO: revisit performance of this kernel with shared memory usage

  for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
    // Convert linear index to 3D index
    const int n = t_i / (H * W); // batch index.
    const int pix_idx = t_i % (H * W);

    // TODO: fix index calculation for non square images.
    const int yi = pix_idx / W;
    const int xi = pix_idx % W;
    scalar_t alpha = 1.0;

    // Loop over all the faces for this pixel.
    for (int k = 0; k < K; k++) {
      // Index into (N, H, W, K) tensors
      const int f = pix_to_face[n][yi][xi][k];
      if (f < 0) {
        // Sentinel value is -1 indicating no face overlaps the pixel.
        continue;
      }
      // The distance is negative if a pixel is inside a face and positive
      // outside the face. Therefore use -1.0 * the distance to get the
      // correct sign.
      scalar_t dist = -1.0 * distances[n][yi][xi][k];

      // Calculate the sigmoid probability.
      scalar_t prob = 1. / (1. + exp(-dist / sigma));

      // The cumulative product ensures that alpha will be 0.0 if at least 1
      // face fully covers the pixel, as for that face prob will be 1.0.
      // This results in a multiplication by 0.0 because of the (1.0 - prob)
      // term. Therefore the final result of (1.0 - alpha) will be 1.0.
      alpha *= (1.0 - prob);
    }
    alphas[n][yi][xi] = 1.0 - alpha;
  }
}
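The kernel above uses a grid-stride loop, so any launch configuration covers all N * H * W pixels. The linear thread index t_i decomposes into the 3D pixel index as

  n = \lfloor t_i / (HW) \rfloor, \qquad
  y_i = \lfloor (t_i \bmod HW) / W \rfloor, \qquad
  x_i = t_i \bmod W,

which is exactly the conversion performed at the top of the loop body.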
torch::Tensor SigmoidAlphaBlendForwardCuda(
    const at::Tensor& distances, // (N, H, W, K)
    const at::Tensor& pix_to_face, // (N, H, W, K)
    const float sigma) {
  const int N = distances.size(0);
  const int H = distances.size(1);
  const int W = distances.size(2);
  const int K = distances.size(3);

  at::Tensor alphas = at::zeros({N, H, W}, distances.options());
  const size_t blocks = 1024;
  const size_t threads = 128;

  // Check inputs are on the same device
  at::TensorArg distances_t{distances, "distances", 1},
      pix_to_face_t{pix_to_face, "pix_to_face", 2};
  at::CheckedFrom c = "SigmoidAlphaBlendForwardCuda";
  at::checkAllSameGPU(c, {distances_t, pix_to_face_t});

  // Set the device for the kernel launch based on the device of distances
  at::cuda::CUDAGuard device_guard(distances.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (distances.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return alphas;
  }

  AT_DISPATCH_FLOATING_TYPES(
      distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
        // clang-format off
        SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
            distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
            pix_to_face.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>(),
            alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
            sigma,
            N,
            H,
            W,
            K);
        // clang-format on
      }));

  AT_CUDA_CHECK(cudaGetLastError());
  return alphas;
}
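A minimal host-side sketch of calling this forward entry point. The shapes, the sigma value, and the example() wrapper are illustrative, not part of this commit; it assumes the file is built as a torch extension with CUDA available:

#include <torch/extension.h>

void example() {
  // N=2 images of 64 x 64 pixels, K=10 faces rasterized per pixel.
  const auto opts =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  at::Tensor distances = torch::randn({2, 64, 64, 10}, opts);
  // Face indices; the sentinel -1 marks "no face overlaps this pixel".
  at::Tensor pix_to_face =
      torch::randint(-1, 100, {2, 64, 64, 10}, opts.dtype(torch::kInt64));
  // sigma controls the sharpness of the sigmoid blend.
  at::Tensor alphas =
      SigmoidAlphaBlendForwardCuda(distances, pix_to_face, /*sigma=*/1e-4);
  // alphas has shape (2, 64, 64) with values in [0, 1].
}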
template <typename scalar_t>
__global__ void SigmoidAlphaBlendBackwardKernel(
    // clang-format off
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> grad_alphas, // (N, H, W)
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> alphas, // (N, H, W)
    const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> distances, // (N, H, W, K)
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> pix_to_face, // (N, H, W, K)
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> grad_distances, // (N, H, W, K)
    // clang-format on
    const scalar_t sigma,
    const int N,
    const int H,
    const int W,
    const int K) {
  // Parallelize over each of the top K faces for each pixel (i.e. over
  // H * W * K values), for each image in the batch of size N.

  // Get block and thread index.
  const int n = blockIdx.x;
  const int num_pixels = H * W * K;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  for (int t_i = tid; t_i < num_pixels; t_i += num_threads) {
    // Convert linear index to 3D index.
    int yi = t_i / (W * K);
    int xi = (t_i % (W * K)) / K;
    int k = (t_i % (W * K)) % K;

    const scalar_t alpha = 1.0 - alphas[n][yi][xi];
    const scalar_t grad_alpha = grad_alphas[n][yi][xi];
    const int f = pix_to_face[n][yi][xi][k];

    // Sentinel value is -1 indicating no face overlaps the pixel.
    if (f >= 0) {
      // The distance is negative if a pixel is inside a face and positive
      // outside the face. Therefore use -1.0 * the distance to get the
      // correct sign.
      scalar_t dist = -1.0 * distances[n][yi][xi][k];

      // Calculate the sigmoid probability.
      scalar_t prob = 1. / (1. + exp(-dist / sigma));

      grad_distances[n][yi][xi][k] = grad_alpha * (-1.0 / sigma) * prob * alpha;
    }
  }
}
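The grad_distances expression above follows from differentiating the forward blend. Writing p_k = sigmoid(-d_k / sigma) for the per-face probability and A = 1 - \prod_j (1 - p_j) for the output alpha:

  \frac{\partial A}{\partial d_k}
    = -\frac{1}{\sigma} \, p_k (1 - p_k) \prod_{j \neq k} (1 - p_j)
    = -\frac{1}{\sigma} \, p_k \, (1 - A),

since (1 - p_k) times the product over j != k recovers the full product 1 - A. The kernel's grad_alpha * (-1.0 / sigma) * prob * alpha is exactly this chain rule, with the local variable alpha holding 1 - alphas[n][yi][xi].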
torch::Tensor SigmoidAlphaBlendBackwardCuda(
    const at::Tensor& grad_alphas, // (N, H, W)
    const at::Tensor& alphas, // (N, H, W)
    const at::Tensor& distances, // (N, H, W, K)
    const at::Tensor& pix_to_face, // (N, H, W, K)
    float sigma) {
  const int N = distances.size(0);
  const int H = distances.size(1);
  const int W = distances.size(2);
  const int K = distances.size(3);

  at::Tensor grad_distances = at::zeros({N, H, W, K}, distances.options());

  const dim3 threads(512);
  const dim3 blocks(N, 1024 / N + 1);

  // Check inputs are on the same device
  at::TensorArg grad_alphas_t{grad_alphas, "grad_alphas", 1},
      alphas_t{alphas, "alphas", 2}, distances_t{distances, "distances", 3},
      pix_to_face_t{pix_to_face, "pix_to_face", 4};
  at::CheckedFrom c = "SigmoidAlphaBlendBackwardCuda";
  at::checkAllSameGPU(c, {grad_alphas_t, alphas_t, distances_t, pix_to_face_t});

  // Set the device for the kernel launch based on the device of alphas
  at::cuda::CUDAGuard device_guard(alphas.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (alphas.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    // Nothing to compute; return the (empty) gradient wrt distances.
    return grad_distances;
  }

  AT_DISPATCH_FLOATING_TYPES(
      distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
        SigmoidAlphaBlendBackwardKernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                // clang-format off
                grad_alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
                alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
                distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                pix_to_face.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>(),
                grad_distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                // clang-format on
                sigma,
                N,
                H,
                W,
                K);
      }));

  AT_CUDA_CHECK(cudaGetLastError());
  return grad_distances;
}
@@ -0,0 +1,97 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#pragma once
#include <torch/extension.h>
#include <tuple>

// clang-format off
// Function to blend the top K faces per pixel based on the 2d euclidean distance
// from the center of the pixel to the face. This method is adapted from [1].
// The output can be used to set the alpha value in an RGBA image.
// Args:
//   pix_to_face: LongTensor of shape (N, H, W, K), indices of faces overlapping
//     with each pixel, where N is the batch size, H, W are the dimensions of the
//     image and K is the number of faces rasterized per pixel.
//   distances: FloatTensor of shape (N, H, W, K), 2d euclidean distance of each pixel
//     relative to the faces in pix_to_face
//   sigma: float, parameter which controls the width of the sigmoid for blending
// Returns:
//   alphas: FloatTensor of shape (N, H, W), the blended values for each pixel
//     in the image.
//
// [1] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
// Image-based 3D Reasoning'
// clang-format on
at::Tensor SigmoidAlphaBlendForwardCpu(
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma);

#ifdef WITH_CUDA
at::Tensor SigmoidAlphaBlendForwardCuda(
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma);
#endif

// clang-format off
// Args:
//   grad_alphas: FloatTensor of shape (N, H, W), upstream gradients for alphas
//   alphas: FloatTensor of shape (N, H, W), the alpha values from the forward pass
//   pix_to_face: LongTensor of shape (N, H, W, K), indices of faces overlapping
//     with each pixel, where N is the batch size, H, W are the dimensions of the
//     image, and K is the number of faces rasterized per pixel
//   distances: FloatTensor of shape (N, H, W, K), 2d euclidean distance of each pixel
//     to the corresponding faces in pix_to_face
//   sigma: float, parameter which controls the width of the sigmoid for blending
// Returns:
//   grad_distances: FloatTensor of shape (N, H, W, K)
// clang-format on
at::Tensor SigmoidAlphaBlendBackwardCpu(
    const at::Tensor& grad_alphas,
    const at::Tensor& alphas,
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma);

#ifdef WITH_CUDA
at::Tensor SigmoidAlphaBlendBackwardCuda(
    const at::Tensor& grad_alphas,
    const at::Tensor& alphas,
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma);
#endif

// Implementation which is exposed.
at::Tensor
SigmoidAlphaBlend(at::Tensor& distances, at::Tensor& pix_to_face, float sigma) {
  if (distances.is_cuda() && pix_to_face.is_cuda()) {
#ifdef WITH_CUDA
    return SigmoidAlphaBlendForwardCuda(distances, pix_to_face, sigma);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return SigmoidAlphaBlendForwardCpu(distances, pix_to_face, sigma);
}

// Implementation which is exposed.
at::Tensor SigmoidAlphaBlendBackward(
    const at::Tensor& grad_alphas,
    const at::Tensor& alphas,
    const at::Tensor& distances,
    const at::Tensor& pix_to_face,
    const float sigma) {
  if (distances.is_cuda() && pix_to_face.is_cuda() && alphas.is_cuda() &&
      grad_alphas.is_cuda()) {
#ifdef WITH_CUDA
    return SigmoidAlphaBlendBackwardCuda(
        grad_alphas, alphas, distances, pix_to_face, sigma);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return SigmoidAlphaBlendBackwardCpu(
      grad_alphas, alphas, distances, pix_to_face, sigma);
}
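These exposed functions would typically be bound to Python through pybind11 elsewhere in the extension; a minimal hypothetical registration sketch (the Python-facing names here are illustrative, not taken from this commit):

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Hypothetical names; the actual bindings live in the extension's main
  // registration file, which is not part of this diff.
  m.def("sigmoid_alpha_blend", &SigmoidAlphaBlend);
  m.def("sigmoid_alpha_blend_backward", &SigmoidAlphaBlendBackward);
}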
