Skip to content

Commit 85c396f

Browse files
bottler authored and facebook-github-bot committed
avoid using torch/extension.h in cuda
Summary: Use aten instead of torch interface in all cuda code. This allows the cuda build to work with pytorch 1.5 with GCC 5 (e.g. the compiler of ubuntu 16.04 LTS). This wasn't working. It has been failing with errors like the below, perhaps due to a bug in nvcc. ``` torch/include/torch/csrc/api/include/torch/nn/cloneable.h:68:61: error: invalid static_cast from type ‘const torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ to type ‘torch::OrderedDict<std::basic_string<char>, std::shared_ptr<torch::nn::Module> >’ ``` Reviewed By: nikhilaravi Differential Revision: D21204029 fbshipit-source-id: ca6bdbcecf42493365e1c23a33fe35e1759fe8b6
1 parent 54b482b commit 85c396f

File tree

9 files changed

+245
-245
lines changed

9 files changed

+245
-245
lines changed

pytorch3d/csrc/compositing/alpha_composite.cu

+34-33
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
#include <torch/extension.h>
3+
#include <ATen/ATen.h>
4+
#include <ATen/core/TensorAccessor.h>
45

56
#include <cuda.h>
67
#include <cuda_runtime.h>
@@ -12,10 +13,10 @@
1213
// Currently, support is for floats only.
1314
__global__ void alphaCompositeCudaForwardKernel(
1415
// clang-format off
15-
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
16-
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
17-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
18-
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
16+
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
17+
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
18+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
19+
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
1920
// clang-format on
2021
const int64_t batch_size = result.size(0);
2122
const int64_t C = features.size(0);
@@ -61,12 +62,12 @@ __global__ void alphaCompositeCudaForwardKernel(
6162
// Currently, support is for floats only.
6263
__global__ void alphaCompositeCudaBackwardKernel(
6364
// clang-format off
64-
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
65-
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
66-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
67-
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
68-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
69-
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
65+
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
66+
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
67+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
68+
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
69+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
70+
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
7071
// clang-format on
7172
const int64_t batch_size = points_idx.size(0);
7273
const int64_t C = features.size(0);
@@ -131,16 +132,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
131132
}
132133
}
133134

134-
torch::Tensor alphaCompositeCudaForward(
135-
const torch::Tensor& features,
136-
const torch::Tensor& alphas,
137-
const torch::Tensor& points_idx) {
135+
at::Tensor alphaCompositeCudaForward(
136+
const at::Tensor& features,
137+
const at::Tensor& alphas,
138+
const at::Tensor& points_idx) {
138139
const int64_t batch_size = points_idx.size(0);
139140
const int64_t C = features.size(0);
140141
const int64_t H = points_idx.size(2);
141142
const int64_t W = points_idx.size(3);
142143

143-
auto result = torch::zeros({batch_size, C, H, W}, features.options());
144+
auto result = at::zeros({batch_size, C, H, W}, features.options());
144145

145146
const dim3 threadsPerBlock(64);
146147
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -149,22 +150,22 @@ torch::Tensor alphaCompositeCudaForward(
149150
// doubles. Currently, support is for floats only.
150151
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
151152
// clang-format off
152-
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
153-
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
154-
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
155-
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
153+
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
154+
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
155+
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
156+
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
156157
// clang-format on
157158

158159
return result;
159160
}
160161

161-
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
162-
const torch::Tensor& grad_outputs,
163-
const torch::Tensor& features,
164-
const torch::Tensor& alphas,
165-
const torch::Tensor& points_idx) {
166-
auto grad_features = torch::zeros_like(features);
167-
auto grad_alphas = torch::zeros_like(alphas);
162+
std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
163+
const at::Tensor& grad_outputs,
164+
const at::Tensor& features,
165+
const at::Tensor& alphas,
166+
const at::Tensor& points_idx) {
167+
auto grad_features = at::zeros_like(features);
168+
auto grad_alphas = at::zeros_like(alphas);
168169

169170
const int64_t bs = alphas.size(0);
170171

@@ -175,12 +176,12 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
175176
// doubles. Currently, support is for floats only.
176177
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
177178
// clang-format off
178-
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
179-
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
180-
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
181-
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
182-
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
183-
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
179+
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
180+
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
181+
grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
182+
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
183+
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
184+
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
184185
// clang-format on
185186

186187
return std::make_tuple(grad_features, grad_alphas);

pytorch3d/csrc/compositing/norm_weighted_sum.cu

+34-33
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
#include <torch/extension.h>
3+
#include <ATen/ATen.h>
4+
#include <ATen/core/TensorAccessor.h>
45

56
#include <cuda.h>
67
#include <cuda_runtime.h>
@@ -14,10 +15,10 @@ __constant__ const float kEpsilon = 1e-4;
1415
// Currently, support is for floats only.
1516
__global__ void weightedSumNormCudaForwardKernel(
1617
// clang-format off
17-
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
18-
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
19-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
20-
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
18+
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
19+
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
20+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
21+
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
2122
// clang-format on
2223
const int64_t batch_size = result.size(0);
2324
const int64_t C = features.size(0);
@@ -76,12 +77,12 @@ __global__ void weightedSumNormCudaForwardKernel(
7677
// Currently, support is for floats only.
7778
__global__ void weightedSumNormCudaBackwardKernel(
7879
// clang-format off
79-
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
80-
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
81-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
82-
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
83-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
84-
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
80+
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
81+
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
82+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
83+
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
84+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
85+
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
8586
// clang-format on
8687
const int64_t batch_size = points_idx.size(0);
8788
const int64_t C = features.size(0);
@@ -146,16 +147,16 @@ __global__ void weightedSumNormCudaBackwardKernel(
146147
}
147148
}
148149

149-
torch::Tensor weightedSumNormCudaForward(
150-
const torch::Tensor& features,
151-
const torch::Tensor& alphas,
152-
const torch::Tensor& points_idx) {
150+
at::Tensor weightedSumNormCudaForward(
151+
const at::Tensor& features,
152+
const at::Tensor& alphas,
153+
const at::Tensor& points_idx) {
153154
const int64_t batch_size = points_idx.size(0);
154155
const int64_t C = features.size(0);
155156
const int64_t H = points_idx.size(2);
156157
const int64_t W = points_idx.size(3);
157158

158-
auto result = torch::zeros({batch_size, C, H, W}, features.options());
159+
auto result = at::zeros({batch_size, C, H, W}, features.options());
159160

160161
const dim3 threadsPerBlock(64);
161162
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -164,22 +165,22 @@ torch::Tensor weightedSumNormCudaForward(
164165
// doubles. Currently, support is for floats only.
165166
// clang-format off
166167
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
167-
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
168-
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
169-
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
170-
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
168+
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
169+
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
170+
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
171+
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
171172
// clang-format on
172173

173174
return result;
174175
}
175176

176-
std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
177-
const torch::Tensor& grad_outputs,
178-
const torch::Tensor& features,
179-
const torch::Tensor& alphas,
180-
const torch::Tensor& points_idx) {
181-
auto grad_features = torch::zeros_like(features);
182-
auto grad_alphas = torch::zeros_like(alphas);
177+
std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
178+
const at::Tensor& grad_outputs,
179+
const at::Tensor& features,
180+
const at::Tensor& alphas,
181+
const at::Tensor& points_idx) {
182+
auto grad_features = at::zeros_like(features);
183+
auto grad_alphas = at::zeros_like(alphas);
183184

184185
const int64_t bs = points_idx.size(0);
185186

@@ -190,12 +191,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
190191
// doubles. Currently, support is for floats only.
191192
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
192193
// clang-format off
193-
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
194-
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
195-
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
196-
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
197-
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
198-
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
194+
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
195+
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
196+
grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
197+
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
198+
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
199+
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
199200
// clang-format on
200201

201202
return std::make_tuple(grad_features, grad_alphas);

pytorch3d/csrc/compositing/weighted_sum.cu

+34-33
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
#include <torch/extension.h>
3+
#include <ATen/ATen.h>
4+
#include <ATen/core/TensorAccessor.h>
45

56
#include <cuda.h>
67
#include <cuda_runtime.h>
@@ -12,10 +13,10 @@
1213
// Currently, support is for floats only.
1314
__global__ void weightedSumCudaForwardKernel(
1415
// clang-format off
15-
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> result,
16-
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
17-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
18-
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
16+
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> result,
17+
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
18+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
19+
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
1920
// clang-format on
2021
const int64_t batch_size = result.size(0);
2122
const int64_t C = features.size(0);
@@ -58,12 +59,12 @@ __global__ void weightedSumCudaForwardKernel(
5859
// Currently, support is for floats only.
5960
__global__ void weightedSumCudaBackwardKernel(
6061
// clang-format off
61-
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> grad_features,
62-
torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_alphas,
63-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> grad_outputs,
64-
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> features,
65-
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits> alphas,
66-
const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> points_idx) {
62+
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> grad_features,
63+
at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_alphas,
64+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> grad_outputs,
65+
const at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> features,
66+
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
67+
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
6768
// clang-format on
6869
const int64_t batch_size = points_idx.size(0);
6970
const int64_t C = features.size(0);
@@ -105,16 +106,16 @@ __global__ void weightedSumCudaBackwardKernel(
105106
}
106107
}
107108

108-
torch::Tensor weightedSumCudaForward(
109-
const torch::Tensor& features,
110-
const torch::Tensor& alphas,
111-
const torch::Tensor& points_idx) {
109+
at::Tensor weightedSumCudaForward(
110+
const at::Tensor& features,
111+
const at::Tensor& alphas,
112+
const at::Tensor& points_idx) {
112113
const int64_t batch_size = points_idx.size(0);
113114
const int64_t C = features.size(0);
114115
const int64_t H = points_idx.size(2);
115116
const int64_t W = points_idx.size(3);
116117

117-
auto result = torch::zeros({batch_size, C, H, W}, features.options());
118+
auto result = at::zeros({batch_size, C, H, W}, features.options());
118119

119120
const dim3 threadsPerBlock(64);
120121
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
@@ -123,22 +124,22 @@ torch::Tensor weightedSumCudaForward(
123124
// doubles. Currently, support is for floats only.
124125
weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
125126
// clang-format off
126-
result.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
127-
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
128-
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
129-
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
127+
result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
128+
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
129+
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
130+
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
130131
// clang-format on
131132

132133
return result;
133134
}
134135

135-
std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
136-
const torch::Tensor& grad_outputs,
137-
const torch::Tensor& features,
138-
const torch::Tensor& alphas,
139-
const torch::Tensor& points_idx) {
140-
auto grad_features = torch::zeros_like(features);
141-
auto grad_alphas = torch::zeros_like(alphas);
136+
std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
137+
const at::Tensor& grad_outputs,
138+
const at::Tensor& features,
139+
const at::Tensor& alphas,
140+
const at::Tensor& points_idx) {
141+
auto grad_features = at::zeros_like(features);
142+
auto grad_alphas = at::zeros_like(alphas);
142143

143144
const int64_t bs = points_idx.size(0);
144145

@@ -149,12 +150,12 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
149150
// doubles. Currently, support is for floats only.
150151
weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
151152
// clang-format off
152-
grad_features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
153-
grad_alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
154-
grad_outputs.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
155-
features.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
156-
alphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
157-
points_idx.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>());
153+
grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
154+
grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
155+
grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
156+
features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
157+
alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
158+
points_idx.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>());
158159
// clang-format on
159160

160161
return std::make_tuple(grad_features, grad_alphas);

pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu

-1
Original file line number · Diff line number · Diff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33
#include <ATen/ATen.h>
4-
#include <torch/extension.h>
54

65
// Kernel for inputs_packed of shape (F, D), where D > 1
76
template <typename scalar_t>

0 commit comments

Comments
 (0)