1
1
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
- #include < torch/extension.h>
3
+ #include < ATen/ATen.h>
4
+ #include < ATen/core/TensorAccessor.h>
4
5
5
6
#include < cuda.h>
6
7
#include < cuda_runtime.h>
@@ -14,10 +15,10 @@ __constant__ const float kEpsilon = 1e-4;
14
15
// Currently, support is for floats only.
15
16
__global__ void weightedSumNormCudaForwardKernel (
16
17
// clang-format off
17
- torch ::PackedTensorAccessor64<float , 4 , torch ::RestrictPtrTraits> result,
18
- const torch ::PackedTensorAccessor64<float , 2 , torch ::RestrictPtrTraits> features,
19
- const torch ::PackedTensorAccessor64<float , 4 , torch ::RestrictPtrTraits> alphas,
20
- const torch ::PackedTensorAccessor64<int64_t , 4 , torch ::RestrictPtrTraits> points_idx) {
18
+ at ::PackedTensorAccessor64<float , 4 , at ::RestrictPtrTraits> result,
19
+ const at ::PackedTensorAccessor64<float , 2 , at ::RestrictPtrTraits> features,
20
+ const at ::PackedTensorAccessor64<float , 4 , at ::RestrictPtrTraits> alphas,
21
+ const at ::PackedTensorAccessor64<int64_t , 4 , at ::RestrictPtrTraits> points_idx) {
21
22
// clang-format on
22
23
const int64_t batch_size = result.size (0 );
23
24
const int64_t C = features.size (0 );
@@ -76,12 +77,12 @@ __global__ void weightedSumNormCudaForwardKernel(
76
77
// Currently, support is for floats only.
77
78
__global__ void weightedSumNormCudaBackwardKernel (
78
79
// clang-format off
79
- torch ::PackedTensorAccessor64<float , 2 , torch ::RestrictPtrTraits> grad_features,
80
- torch ::PackedTensorAccessor64<float , 4 , torch ::RestrictPtrTraits> grad_alphas,
81
- const torch ::PackedTensorAccessor64<float , 4 , torch ::RestrictPtrTraits> grad_outputs,
82
- const torch ::PackedTensorAccessor64<float , 2 , torch ::RestrictPtrTraits> features,
83
- const torch ::PackedTensorAccessor64<float , 4 , torch ::RestrictPtrTraits> alphas,
84
- const torch ::PackedTensorAccessor64<int64_t , 4 , torch ::RestrictPtrTraits> points_idx) {
80
+ at ::PackedTensorAccessor64<float , 2 , at ::RestrictPtrTraits> grad_features,
81
+ at ::PackedTensorAccessor64<float , 4 , at ::RestrictPtrTraits> grad_alphas,
82
+ const at ::PackedTensorAccessor64<float , 4 , at ::RestrictPtrTraits> grad_outputs,
83
+ const at ::PackedTensorAccessor64<float , 2 , at ::RestrictPtrTraits> features,
84
+ const at ::PackedTensorAccessor64<float , 4 , at ::RestrictPtrTraits> alphas,
85
+ const at ::PackedTensorAccessor64<int64_t , 4 , at ::RestrictPtrTraits> points_idx) {
85
86
// clang-format on
86
87
const int64_t batch_size = points_idx.size (0 );
87
88
const int64_t C = features.size (0 );
@@ -146,16 +147,16 @@ __global__ void weightedSumNormCudaBackwardKernel(
146
147
}
147
148
}
148
149
149
- torch ::Tensor weightedSumNormCudaForward (
150
- const torch ::Tensor& features,
151
- const torch ::Tensor& alphas,
152
- const torch ::Tensor& points_idx) {
150
+ at ::Tensor weightedSumNormCudaForward (
151
+ const at ::Tensor& features,
152
+ const at ::Tensor& alphas,
153
+ const at ::Tensor& points_idx) {
153
154
const int64_t batch_size = points_idx.size (0 );
154
155
const int64_t C = features.size (0 );
155
156
const int64_t H = points_idx.size (2 );
156
157
const int64_t W = points_idx.size (3 );
157
158
158
- auto result = torch ::zeros ({batch_size, C, H, W}, features.options ());
159
+ auto result = at ::zeros ({batch_size, C, H, W}, features.options ());
159
160
160
161
const dim3 threadsPerBlock (64 );
161
162
const dim3 numBlocks (batch_size, 1024 / batch_size + 1 );
@@ -164,22 +165,22 @@ torch::Tensor weightedSumNormCudaForward(
164
165
// doubles. Currently, support is for floats only.
165
166
// clang-format off
166
167
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>> (
167
- result.packed_accessor64 <float , 4 , torch ::RestrictPtrTraits>(),
168
- features.packed_accessor64 <float , 2 , torch ::RestrictPtrTraits>(),
169
- alphas.packed_accessor64 <float , 4 , torch ::RestrictPtrTraits>(),
170
- points_idx.packed_accessor64 <int64_t , 4 , torch ::RestrictPtrTraits>());
168
+ result.packed_accessor64 <float , 4 , at ::RestrictPtrTraits>(),
169
+ features.packed_accessor64 <float , 2 , at ::RestrictPtrTraits>(),
170
+ alphas.packed_accessor64 <float , 4 , at ::RestrictPtrTraits>(),
171
+ points_idx.packed_accessor64 <int64_t , 4 , at ::RestrictPtrTraits>());
171
172
// clang-format on
172
173
173
174
return result;
174
175
}
175
176
176
- std::tuple<torch ::Tensor, torch ::Tensor> weightedSumNormCudaBackward (
177
- const torch ::Tensor& grad_outputs,
178
- const torch ::Tensor& features,
179
- const torch ::Tensor& alphas,
180
- const torch ::Tensor& points_idx) {
181
- auto grad_features = torch ::zeros_like (features);
182
- auto grad_alphas = torch ::zeros_like (alphas);
177
+ std::tuple<at ::Tensor, at ::Tensor> weightedSumNormCudaBackward (
178
+ const at ::Tensor& grad_outputs,
179
+ const at ::Tensor& features,
180
+ const at ::Tensor& alphas,
181
+ const at ::Tensor& points_idx) {
182
+ auto grad_features = at ::zeros_like (features);
183
+ auto grad_alphas = at ::zeros_like (alphas);
183
184
184
185
const int64_t bs = points_idx.size (0 );
185
186
@@ -190,12 +191,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
190
191
// doubles. Currently, support is for floats only.
191
192
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>> (
192
193
// clang-format off
193
- grad_features.packed_accessor64 <float , 2 , torch ::RestrictPtrTraits>(),
194
- grad_alphas.packed_accessor64 <float , 4 , torch ::RestrictPtrTraits>(),
195
- grad_outputs.packed_accessor64 <float , 4 , torch ::RestrictPtrTraits>(),
196
- features.packed_accessor64 <float , 2 , torch ::RestrictPtrTraits>(),
197
- alphas.packed_accessor64 <float , 4 , torch ::RestrictPtrTraits>(),
198
- points_idx.packed_accessor64 <int64_t , 4 , torch ::RestrictPtrTraits>());
194
+ grad_features.packed_accessor64 <float , 2 , at ::RestrictPtrTraits>(),
195
+ grad_alphas.packed_accessor64 <float , 4 , at ::RestrictPtrTraits>(),
196
+ grad_outputs.packed_accessor64 <float , 4 , at ::RestrictPtrTraits>(),
197
+ features.packed_accessor64 <float , 2 , at ::RestrictPtrTraits>(),
198
+ alphas.packed_accessor64 <float , 4 , at ::RestrictPtrTraits>(),
199
+ points_idx.packed_accessor64 <int64_t , 4 , at ::RestrictPtrTraits>());
199
200
// clang-format on
200
201
201
202
return std::make_tuple (grad_features, grad_alphas);
0 commit comments