Add fp16 support for sample_farthest_points, ball_query and knn #1929

Open · wants to merge 9 commits into base: main
14 changes: 7 additions & 7 deletions pytorch3d/csrc/ball_query/ball_query.cu
@@ -32,7 +32,7 @@ __global__ void BallQueryKernel(
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
const int64_t K,
- const float radius2) {
+ const scalar_t radius2) {
const int64_t N = p1.size(0);
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -95,7 +95,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const int N = p1.size(0);
const int P1 = p1.size(1);
const int64_t K_64 = K;
- const float radius2 = radius * radius;
+ const auto radius2 = radius * radius;

// Output tensor with indices of neighbors for each point in p1
auto long_dtype = lengths1.options().dtype(at::kLong);
@@ -110,15 +110,15 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const size_t blocks = 256;
const size_t threads = 256;

- AT_DISPATCH_FLOATING_TYPES(
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
- BallQueryKernel<<<blocks, threads, 0, stream>>>(
- p1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
- p2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
+ BallQueryKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+ p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
- dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
+ dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
K_64,
radius2);
}));
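Note on the dispatch change above: AT_DISPATCH_FLOATING_TYPES_AND_HALF extends the float/double dispatch with a case for at::Half, binding scalar_t to at::Half inside the lambda, so fp16 inputs reach the templated kernel instead of failing dispatch. A minimal sketch of the pattern, separate from this diff (function name hypothetical):

```cpp
#include <ATen/ATen.h>

// Sketch: the macro switches on the tensor's dtype and instantiates the
// lambda once per supported type, with scalar_t bound accordingly.
void dispatch_example(const at::Tensor& t) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      t.scalar_type(), "dispatch_example", ([&] {
        scalar_t zero = static_cast<scalar_t>(0);  // at::Half for an fp16 input
        (void)zero;  // a real caller launches a kernel templated on scalar_t
      }));
}
```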
73 changes: 37 additions & 36 deletions pytorch3d/csrc/knn/knn.cu
@@ -8,8 +8,8 @@

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
+ #include <ATen/cuda/Atomic.cuh>
#include <c10/cuda/CUDAGuard.h>
- #include <float.h>
#include <iostream>
#include <tuple>

@@ -57,7 +57,7 @@ __global__ void KNearestNeighborKernelV0(
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
scalar_t diff = coord1 - coord2;
- scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
dist += norm_diff;
}
mink.add(dist, p2);
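Note on the abs(diff) replacement above (the same change repeats in the V1–V3 kernels below): device-side abs is not guaranteed to resolve for at::Half once it is dispatched, so the explicit ternary is used instead; it works for every dispatched scalar_t. A standalone sketch of the same idea (helper name hypothetical):

```cuda
// Sketch: type-generic absolute difference that avoids relying on an
// abs() overload for at::Half in device code.
template <typename scalar_t>
__device__ __forceinline__ scalar_t abs_diff(scalar_t a, scalar_t b) {
  scalar_t diff = a - b;
  return (diff > 0) ? diff : -diff;  // same pattern as the kernels above
}
```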
@@ -102,7 +102,7 @@ __global__ void KNearestNeighborKernelV1(
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
- scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
dist += norm_diff;
}
mink.add(dist, p2);
@@ -167,7 +167,7 @@ __global__ void KNearestNeighborKernelV2(
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
- scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
dist += norm_diff;
}
mink.add(dist, p2);
@@ -238,7 +238,7 @@ __global__ void KNearestNeighborKernelV3(
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
- scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
dist += norm_diff;
}
mink.add(dist, p2);
@@ -367,7 +367,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const size_t threads = 256;
const size_t blocks = 256;
if (version == 0) {
- AT_DISPATCH_FLOATING_TYPES(
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
p1.scalar_type(), "knn_kernel_cuda", ([&] {
KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
p1.contiguous().data_ptr<scalar_t>(),
@@ -384,7 +384,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
norm);
}));
} else if (version == 1) {
- AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel1D<
KNearestNeighborV1Functor,
scalar_t,
@@ -406,7 +406,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
norm);
}));
} else if (version == 2) {
- AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV2Functor,
scalar_t,
@@ -430,7 +430,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
norm);
}));
} else if (version == 3) {
- AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_kernel_cuda", ([&] {
DispatchKernel2D<
KNearestNeighborKernelV3Functor,
scalar_t,
@@ -462,17 +462,16 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
// Backward Operators //
// ------------------------------------------------------------- //

- // TODO(gkioxari) support all data types once AtomicAdd supports doubles.
- // Currently, support is for floats only.
+ template <typename scalar_t>
__global__ void KNearestNeighborBackwardKernel(
- const float* __restrict__ p1, // (N, P1, D)
- const float* __restrict__ p2, // (N, P2, D)
+ const scalar_t* __restrict__ p1, // (N, P1, D)
+ const scalar_t* __restrict__ p2, // (N, P2, D)
const int64_t* __restrict__ lengths1, // (N,)
const int64_t* __restrict__ lengths2, // (N,)
const int64_t* __restrict__ idxs, // (N, P1, K)
- const float* __restrict__ grad_dists, // (N, P1, K)
- float* __restrict__ grad_p1, // (N, P1, D)
- float* __restrict__ grad_p2, // (N, P2, D)
+ const scalar_t* __restrict__ grad_dists, // (N, P1, K)
+ scalar_t* __restrict__ grad_p1, // (N, P1, D)
+ scalar_t* __restrict__ grad_p2, // (N, P2, D)
const size_t N,
const size_t P1,
const size_t P2,
@@ -493,16 +492,16 @@ __global__ void KNearestNeighborBackwardKernel(
const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
if ((p1_idx < num1) && (k < num2)) {
- const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
+ const scalar_t grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
// index of point in p2 corresponding to the k-th nearest neighbor
const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
// If the index is the pad value of -1 then ignore it
if (p2_idx == -1) {
continue;
}
- float diff = 0.0;
+ scalar_t diff = 0.0;
if (norm == 1) {
- float sign =
+ scalar_t sign =
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
? 1.0
: -1.0;
@@ -511,8 +510,8 @@ __global__ void KNearestNeighborBackwardKernel(
diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
}
- atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
- atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
+ gpuAtomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
+ gpuAtomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
}
}
}
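Note on gpuAtomicAdd: this is ATen's type-generic atomic add from ATen/cuda/Atomic.cuh (hence the new include at the top of this file). Unlike raw atomicAdd, it handles at::Half, as well as double on architectures without native support, which is also what makes the removed TODO obsolete. A minimal sketch (kernel name hypothetical):

```cuda
#include <ATen/cuda/Atomic.cuh>

// Sketch: gpuAtomicAdd accepts float, double, and at::Half, whereas raw
// atomicAdd has no overload for c10::Half.
template <typename scalar_t>
__global__ void accumulate_kernel(scalar_t* out, const scalar_t* in, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    gpuAtomicAdd(out, in[i]);  // safe concurrent accumulation into *out
  }
}
```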
@@ -566,21 +565,23 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const int blocks = 64;
const int threads = 512;

- KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
- p1.contiguous().data_ptr<float>(),
- p2.contiguous().data_ptr<float>(),
- lengths1.contiguous().data_ptr<int64_t>(),
- lengths2.contiguous().data_ptr<int64_t>(),
- idxs.contiguous().data_ptr<int64_t>(),
- grad_dists.contiguous().data_ptr<float>(),
- grad_p1.data_ptr<float>(),
- grad_p2.data_ptr<float>(),
- N,
- P1,
- P2,
- K,
- D,
- norm);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_backward_kernel_cuda", ([&] {
+ KNearestNeighborBackwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ idxs.contiguous().data_ptr<int64_t>(),
+ grad_dists.contiguous().data_ptr<scalar_t>(),
+ grad_p1.data_ptr<scalar_t>(),
+ grad_p2.data_ptr<scalar_t>(),
+ N,
+ P1,
+ P2,
+ K,
+ D,
+ norm);
+ }));

AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);
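With both the forward and backward dispatches widened, an fp16 call should now run end to end. A smoke-test sketch (the KNearestNeighborIdxCuda signature is assumed from knn.h; shapes and values are arbitrary):

```cpp
#include <ATen/ATen.h>
#include <tuple>

// Assumed declaration, as in pytorch3d/csrc/knn/knn.h.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p1, const at::Tensor& p2,
    const at::Tensor& lengths1, const at::Tensor& lengths2,
    int norm, int K, int version);

void knn_half_smoke_test() {
  const auto opts = at::device(at::kCUDA).dtype(at::kHalf);
  at::Tensor p1 = at::rand({1, 64, 3}, opts);
  at::Tensor p2 = at::rand({1, 128, 3}, opts);
  at::Tensor lengths1 = at::full({1}, 64, opts.dtype(at::kLong));
  at::Tensor lengths2 = at::full({1}, 128, opts.dtype(at::kLong));
  // Before this change, the float-only dispatch raised an error like
  // "knn_kernel_cuda" not implemented for 'Half'.
  auto out = KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2,
                                     /*norm=*/2, /*K=*/8, /*version=*/0);
}
```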