diff --git a/pytorch3d/csrc/ball_query/ball_query.cu b/pytorch3d/csrc/ball_query/ball_query.cu
index 586701c18..b279517ef 100644
--- a/pytorch3d/csrc/ball_query/ball_query.cu
+++ b/pytorch3d/csrc/ball_query/ball_query.cu
@@ -32,7 +32,7 @@ __global__ void BallQueryKernel(
     at::PackedTensorAccessor64 idxs,
     at::PackedTensorAccessor64 dists,
     const int64_t K,
-    const float radius2) {
+    const scalar_t radius2) {
   const int64_t N = p1.size(0);
   const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -95,7 +95,7 @@ std::tuple BallQueryCuda(
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int64_t K_64 = K;
-  const float radius2 = radius * radius;
+  const auto radius2 = radius * radius;

   // Output tensor with indices of neighbors for each point in p1
   auto long_dtype = lengths1.options().dtype(at::kLong);
@@ -110,15 +110,15 @@ std::tuple BallQueryCuda(
   const size_t blocks = 256;
   const size_t threads = 256;

-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
-        BallQueryKernel<<>>(
-            p1.packed_accessor64(),
-            p2.packed_accessor64(),
+        BallQueryKernel<<>>(
+            p1.packed_accessor64(),
+            p2.packed_accessor64(),
             lengths1.packed_accessor64(),
             lengths2.packed_accessor64(),
             idxs.packed_accessor64(),
-            dists.packed_accessor64(),
+            dists.packed_accessor64(),
             K_64,
             radius2);
       }));
diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu
index ad9dce247..6e0174dfb 100644
--- a/pytorch3d/csrc/knn/knn.cu
+++ b/pytorch3d/csrc/knn/knn.cu
@@ -8,8 +8,8 @@
 #include
 #include
+#include
 #include
-#include
 #include
 #include
@@ -57,7 +57,7 @@ __global__ void KNearestNeighborKernelV0(
         scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
         scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
         scalar_t diff = coord1 - coord2;
-        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
         dist += norm_diff;
       }
       mink.add(dist, p2);
@@ -102,7 +102,7 @@ __global__ void KNearestNeighborKernelV1(
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
-        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
         dist += norm_diff;
       }
       mink.add(dist, p2);
@@ -167,7 +167,7 @@ __global__ void KNearestNeighborKernelV2(
         for (int d = 0; d < D; ++d) {
           int offset = n * P2 * D + p2 * D + d;
           scalar_t diff = cur_point[d] - points2[offset];
-          scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+          scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
           dist += norm_diff;
         }
         mink.add(dist, p2);
@@ -238,7 +238,7 @@ __global__ void KNearestNeighborKernelV3(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : ((diff > 0) ? diff : -diff);
        dist += norm_diff;
      }
      mink.add(dist, p2);
@@ -367,7 +367,7 @@ std::tuple KNearestNeighborIdxCuda(
   const size_t threads = 256;
   const size_t blocks = 256;
   if (version == 0) {
-    AT_DISPATCH_FLOATING_TYPES(
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         p1.scalar_type(), "knn_kernel_cuda", ([&] {
           KNearestNeighborKernelV0<<>>(
               p1.contiguous().data_ptr(),
@@ -384,7 +384,7 @@ std::tuple KNearestNeighborIdxCuda(
               norm);
         }));
   } else if (version == 1) {
-    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_kernel_cuda", ([&] {
       DispatchKernel1D<
           KNearestNeighborV1Functor,
           scalar_t,
@@ -406,7 +406,7 @@ std::tuple KNearestNeighborIdxCuda(
           norm);
     }));
   } else if (version == 2) {
-    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_kernel_cuda", ([&] {
       DispatchKernel2D<
           KNearestNeighborKernelV2Functor,
           scalar_t,
@@ -430,7 +430,7 @@ std::tuple KNearestNeighborIdxCuda(
           norm);
     }));
   } else if (version == 3) {
-    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_kernel_cuda", ([&] {
       DispatchKernel2D<
           KNearestNeighborKernelV3Functor,
           scalar_t,
@@ -462,17 +462,16 @@ std::tuple KNearestNeighborIdxCuda(
 //                       Backward Operators                       //
 // ------------------------------------------------------------- //

-// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
-// Currently, support is for floats only.
+template
 __global__ void KNearestNeighborBackwardKernel(
-    const float* __restrict__ p1, // (N, P1, D)
-    const float* __restrict__ p2, // (N, P2, D)
+    const scalar_t* __restrict__ p1, // (N, P1, D)
+    const scalar_t* __restrict__ p2, // (N, P2, D)
     const int64_t* __restrict__ lengths1, // (N,)
     const int64_t* __restrict__ lengths2, // (N,)
     const int64_t* __restrict__ idxs, // (N, P1, K)
-    const float* __restrict__ grad_dists, // (N, P1, K)
-    float* __restrict__ grad_p1, // (N, P1, D)
-    float* __restrict__ grad_p2, // (N, P2, D)
+    const scalar_t* __restrict__ grad_dists, // (N, P1, K)
+    scalar_t* __restrict__ grad_p1, // (N, P1, D)
+    scalar_t* __restrict__ grad_p2, // (N, P2, D)
     const size_t N,
     const size_t P1,
     const size_t P2,
@@ -493,16 +492,16 @@ __global__ void KNearestNeighborBackwardKernel(
     const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
     const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
     if ((p1_idx < num1) && (k < num2)) {
-      const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
+      const scalar_t grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
       // index of point in p2 corresponding to the k-th nearest neighbor
       const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
       // If the index is the pad value of -1 then ignore it
      if (p2_idx == -1) {
         continue;
       }
-      float diff = 0.0;
+      scalar_t diff = 0.0;
       if (norm == 1) {
-        float sign =
+        scalar_t sign =
             (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d]) ?
            1.0 : -1.0;
@@ -511,8 +510,8 @@ __global__ void KNearestNeighborBackwardKernel(
         diff = 2.0 * grad_dist *
             (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
       }
-      atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
-      atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
+      gpuAtomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
+      gpuAtomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
     }
   }
 }
@@ -566,21 +565,23 @@ std::tuple KNearestNeighborBackwardCuda(
   const int blocks = 64;
   const int threads = 512;

-  KNearestNeighborBackwardKernel<<>>(
-      p1.contiguous().data_ptr(),
-      p2.contiguous().data_ptr(),
-      lengths1.contiguous().data_ptr(),
-      lengths2.contiguous().data_ptr(),
-      idxs.contiguous().data_ptr(),
-      grad_dists.contiguous().data_ptr(),
-      grad_p1.data_ptr(),
-      grad_p2.data_ptr(),
-      N,
-      P1,
-      P2,
-      K,
-      D,
-      norm);
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(p1.scalar_type(), "knn_backward_kernel_cuda", ([&] {
+    KNearestNeighborBackwardKernel<<>>(
+        p1.contiguous().data_ptr(),
+        p2.contiguous().data_ptr(),
+        lengths1.contiguous().data_ptr(),
+        lengths2.contiguous().data_ptr(),
+        idxs.contiguous().data_ptr(),
+        grad_dists.contiguous().data_ptr(),
+        grad_p1.data_ptr(),
+        grad_p2.data_ptr(),
+        N,
+        P1,
+        P2,
+        K,
+        D,
+        norm);
+  }));

   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_p1, grad_p2);
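
Note: with the forward dispatch switched to AT_DISPATCH_FLOATING_TYPES_AND_HALF and the backward kernel templated on scalar_t (accumulating through gpuAtomicAdd), the intent of the knn.cu changes above is that knn_points works end to end in the caller's floating dtype. A minimal usage sketch, assuming the patched extension is built with CUDA; shapes and K are illustrative only:

    import torch
    from pytorch3d.ops import knn_points

    # Two batched point clouds on the GPU, stored in half precision.
    p1 = torch.randn(2, 128, 3, device="cuda", dtype=torch.float16, requires_grad=True)
    p2 = torch.randn(2, 256, 3, device="cuda", dtype=torch.float16, requires_grad=True)

    # K nearest neighbors; dists keep the input dtype instead of being cast to float32.
    knn = knn_points(p1, p2, K=8, norm=2)

    # The templated backward kernel accumulates gradients in the same dtype.
    knn.dists.sum().backward()
    print(p1.grad.dtype)  # torch.float16 expected with this patch
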
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
index 70cef75c7..538506506 100644
--- a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
@@ -15,19 +15,19 @@
 #include
 #include "utils/warp_reduce.cuh"

-template
+template
 __global__ void FarthestPointSamplingKernel(
     // clang-format off
-    const at::PackedTensorAccessor64 points,
+    const at::PackedTensorAccessor64 points,
     const at::PackedTensorAccessor64 lengths,
     const at::PackedTensorAccessor64 K,
     at::PackedTensorAccessor64 idxs,
-    at::PackedTensorAccessor64 min_point_dist,
+    at::PackedTensorAccessor64 min_point_dist,
     const at::PackedTensorAccessor64 start_idxs
     // clang-format on
 ) {
   typedef cub::BlockReduce<
-      cub::KeyValuePair,
+      cub::KeyValuePair,
       block_size,
       cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
       BlockReduce;
@@ -57,16 +57,16 @@ __global__ void FarthestPointSamplingKernel(
   // Keep track of the maximum of the minimum distance to previously selected
   // points seen by this thread
   int64_t max_dist_idx = 0;
-  float max_dist = -1.0;
+  scalar_t max_dist = -1.0;

   // Iterate through all the points in this pointcloud. For already selected
   // points, the minimum distance to the set of previously selected points
   // will be 0.0 so they won't be selected again.
   for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) {
     // Calculate the distance to the last selected point
-    float dist2 = 0.0;
+    scalar_t dist2 = 0.0;
     for (int64_t d = 0; d < D; ++d) {
-      float diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
+      scalar_t diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
       dist2 += (diff * diff);
     }

@@ -74,7 +74,7 @@ __global__ void FarthestPointSamplingKernel(
     // If the distance of point p to the last selected point is
     // less than the previous minimum distance of p to the set of selected
     // points, then updated the corresponding value in min_point_dist
     // so it always contains the min distance.
-    const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
+    const scalar_t p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
     min_point_dist[batch_idx][p] = p_min_dist;

     // Update the max distance and point idx for this thread.
@@ -88,7 +88,7 @@ __global__ void FarthestPointSamplingKernel(
     selected = BlockReduce(temp_storage)
                    .Reduce(
-                       cub::KeyValuePair(max_dist_idx, max_dist),
+                       cub::KeyValuePair(max_dist_idx, max_dist),
                        cub::ArgMax(),
                        block_size)
                    .key;
@@ -109,13 +109,16 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points, // (N, P, 3)
     const at::Tensor& lengths, // (N,)
     const at::Tensor& K, // (N,)
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs, // (N, P)
+    const at::Tensor& min_point_dist
+    ) {
   // Check inputs are on the same device
-  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
-      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
-  at::CheckedFrom c = "FarthestPointSamplingCuda";
-  at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t});
-  at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});
+  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
+      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4}, min_point_dist_t{min_point_dist, "min_point_dist", 5};
+  at::CheckedFrom c = "FarthestPointSamplingCuda";
+  at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t, min_point_dist_t});
+  at::checkAllSameType(c, {p_t, min_point_dist_t});
+  at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});

   // Set the device for the kernel launch based on the device of points
   at::cuda::CUDAGuard device_guard(points.device());
@@ -135,7 +138,6 @@ at::Tensor FarthestPointSamplingCuda(
   // Initialize the output tensor with the sampled indices
   auto idxs = at::full({N, max_K}, -1, lengths.options());
-  auto min_point_dist = at::full({N, P}, 1e10, points.options());

   if (N == 0 || P == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
@@ -158,15 +160,10 @@ at::Tensor FarthestPointSamplingCuda(
   const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 2);

   // Create the accessors
-  auto points_a = points.packed_accessor64();
-  auto lengths_a =
-      lengths.packed_accessor64();
+  auto lengths_a = lengths.packed_accessor64();
   auto K_a = K.packed_accessor64();
   auto idxs_a = idxs.packed_accessor64();
-  auto start_idxs_a =
-      start_idxs.packed_accessor64();
-  auto min_point_dist_a =
-      min_point_dist.packed_accessor64();
+  auto start_idxs_a = start_idxs.packed_accessor64();

   // TempStorage for the reduction uses static shared memory only.
   size_t shared_mem = 0;
@@ -175,50 +172,124 @@ at::Tensor FarthestPointSamplingCuda(
   // block.
   switch (threads) {
     case 1024:
-      FarthestPointSamplingKernel<1024>
-          <<>>(
-              points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<1024, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 512:
-      FarthestPointSamplingKernel<512><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<512, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 256:
-      FarthestPointSamplingKernel<256><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<256, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          }));
       break;
     case 128:
-      FarthestPointSamplingKernel<128><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<128, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 64:
-      FarthestPointSamplingKernel<64><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<64, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 32:
-      FarthestPointSamplingKernel<32><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<32, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 16:
-      FarthestPointSamplingKernel<16><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<16, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 8:
-      FarthestPointSamplingKernel<8><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<8, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 4:
-      FarthestPointSamplingKernel<4><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<4, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     case 2:
-      FarthestPointSamplingKernel<2><<>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<2, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
       break;
     default:
-      FarthestPointSamplingKernel<1024>
-          <<>>(
-              points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<1024, scalar_t>
+                <<>>(
+                    points.packed_accessor64(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64(), start_idxs_a
+                );
+          })
+      );
   }

   AT_CUDA_CHECK(cudaGetLastError());
diff --git a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
index 7b613d358..aa7b8f104 100644
--- a/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
+++ b/pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs);
+    const at::Tensor& start_idxs,
+    const at::Tensor& min_point_dist);

 at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
@@ -56,14 +57,16 @@ at::Tensor FarthestPointSampling(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const at::Tensor& min_point_dist) {
   if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(points);
     CHECK_CUDA(lengths);
     CHECK_CUDA(K);
     CHECK_CUDA(start_idxs);
-    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
+    CHECK_CUDA(min_point_dist);
+    return FarthestPointSamplingCuda(points, lengths, K, start_idxs, min_point_dist);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
diff --git a/pytorch3d/ops/ball_query.py b/pytorch3d/ops/ball_query.py
index 31266c4d2..b852e1dbd 100644
--- a/pytorch3d/ops/ball_query.py
+++ b/pytorch3d/ops/ball_query.py
@@ -36,13 +36,6 @@ def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
     @once_differentiable
     def backward(ctx, grad_dists, grad_idx):
         p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
-        # TODO(gkioxari) Change cast to floats once we add support for doubles.
-        if not (grad_dists.dtype == torch.float32):
-            grad_dists = grad_dists.float()
-        if not (p1.dtype == torch.float32):
-            p1 = p1.float()
-        if not (p2.dtype == torch.float32):
-            p2 = p2.float()

         # Reuse the KNN backward function
         # by default, norm is 2
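
Note: with the float32 casts removed from the Python backward above, gradients are expected to flow through ball_query in the caller's own dtype. A small sketch of the intended behavior, assuming the rebuilt extension; the radius, K, and shapes are illustrative:

    import torch
    from pytorch3d.ops import ball_query

    # Double-precision inputs; previously the backward pass cast these to float32.
    p1 = torch.randn(1, 64, 3, device="cuda", dtype=torch.float64, requires_grad=True)
    p2 = torch.randn(1, 128, 3, device="cuda", dtype=torch.float64, requires_grad=True)

    out = ball_query(p1, p2, K=16, radius=0.5)
    out.dists.sum().backward()
    print(p1.grad.dtype)  # torch.float64 expected once the casts are gone
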
diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py
index 114334fda..9b5022dcb 100644
--- a/pytorch3d/ops/knn.py
+++ b/pytorch3d/ops/knn.py
@@ -98,13 +98,6 @@ def forward(
     def backward(ctx, grad_dists, grad_idx):
         p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
         norm = ctx.norm
-        # TODO(gkioxari) Change cast to floats once we add support for doubles.
-        if not (grad_dists.dtype == torch.float32):
-            grad_dists = grad_dists.float()
-        if not (p1.dtype == torch.float32):
-            p1 = p1.float()
-        if not (p2.dtype == torch.float32):
-            p2 = p2.float()
         grad_p1, grad_p2 = _C.knn_points_backward(
             p1, p2, lengths1, lengths2, idx, norm, grad_dists
         )
diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py
index a45b1de22..b0edfd67f 100644
--- a/pytorch3d/ops/sample_farthest_points.py
+++ b/pytorch3d/ops/sample_farthest_points.py
@@ -74,8 +74,6 @@ def sample_farthest_points(
         raise ValueError("K and points must have the same batch dimension")

     # Check dtypes are correct and convert if necessary
-    if not (points.dtype == torch.float32):
-        points = points.to(torch.float32)
     if not (lengths.dtype == torch.int64):
         lengths = lengths.to(torch.int64)
     if not (K.dtype == torch.int64):
@@ -83,6 +81,12 @@ def sample_farthest_points(

     # Generate the starting indices for sampling
     start_idxs = torch.zeros_like(lengths)
+
+    # Generate the minimum point distance array
+    min_point_dist = torch.full(
+        (N, P), torch.finfo(points.dtype).max, dtype=points.dtype, device=device
+    )
+
     if random_start_point:
         for n in range(N):
             # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
@@ -90,7 +94,7 @@ def sample_farthest_points(

     with torch.no_grad():
         # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
-        idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
+        idx = _C.sample_farthest_points(points, lengths, K, start_idxs, min_point_dist)

     sampled_points = masked_gather(points, idx)

     return sampled_points, idx
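
Note: on the Python side, sample_farthest_points now allocates min_point_dist in the input dtype, seeded with torch.finfo(points.dtype).max, and passes it into the C++ op, so half and double point clouds are no longer cast to float32. A usage sketch under the assumption that the patched _C extension is installed; the cloud sizes and K are illustrative:

    import torch
    from pytorch3d.ops import sample_farthest_points

    points = torch.randn(4, 1000, 3, device="cuda", dtype=torch.float16)
    lengths = torch.full((4,), 1000, device="cuda", dtype=torch.int64)

    # K can be an int or a per-cloud tensor; sampling runs in the points' own dtype.
    sampled, idx = sample_farthest_points(points, lengths=lengths, K=64, random_start_point=True)
    print(sampled.dtype, idx.dtype)  # torch.float16 torch.int64 expected with this patch
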