Skip to content

Commit d8ef434

Browse files
committed
attempt to fix issue #9
1 parent a7de67a commit d8ef434

File tree

4 files changed

+43
-33
lines changed

4 files changed

+43
-33
lines changed

script/functions/stn.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,33 @@
22
import torch
33
from torch.autograd import Function
44
from _ext import my_lib
5-
5+
from cffi import FFI
6+
ffi = FFI()
67

78
class STNFunction(Function):
    """Bilinear-sampler (spatial transformer) autograd Function.

    Bridges to the cffi extension ``my_lib``; tensors are assumed to be in
    BHWD layout (per the comments in my_lib_cuda.c). This is an old-style
    autograd Function: inputs and CUDA bookkeeping are stashed on ``self``
    between ``forward`` and ``backward``.
    """

    def forward(self, input1, input2):
        # input1: images, input2: sampling grids — presumably BHWD / BHW(YX)
        # as stated in my_lib_cuda.c; TODO confirm exact shapes with callers.
        self.input1 = input1
        self.input2 = input2
        output = torch.zeros(input1.size()[0], input2.size()[1],
                             input2.size()[2], input1.size()[3])
        if not input1.is_cuda:
            my_lib.BilinearSamplerBHWD_updateOutput(input1, input2, output)
        else:
            # Query CUDA state only on the GPU path: calling
            # torch.cuda.current_device() unconditionally (as before) raises
            # on CPU-only builds even though the CPU branch above is supported.
            self.device = torch.cuda.current_device()
            # The C side reads device[0] and calls cudaSetDevice() with it,
            # so the kernels launch on the tensor's device (issue #9).
            self.device_c = ffi.new("int *")
            self.device_c[0] = self.device
            output = output.cuda(self.device)
            my_lib.BilinearSamplerBHWD_updateOutput_cuda(
                input1, input2, output, self.device_c)
        return output

    def backward(self, grad_output):
        # Gradients start as CPU zeros shaped like the saved inputs.
        grad_input1 = torch.zeros(self.input1.size())
        grad_input2 = torch.zeros(self.input2.size())
        if not grad_output.is_cuda:
            my_lib.BilinearSamplerBHWD_updateGradInput(
                self.input1, self.input2, grad_input1, grad_input2,
                grad_output)
        else:
            # self.device / self.device_c were set in forward's CUDA branch;
            # grad_output.is_cuda implies forward ran on CUDA.
            grad_input1 = grad_input1.cuda(self.device)
            grad_input2 = grad_input2.cuda(self.device)
            my_lib.BilinearSamplerBHWD_updateGradInput_cuda(
                self.input1, self.input2, grad_input1, grad_input2,
                grad_output, self.device_c)
        return grad_input1, grad_input2

script/src/my_lib_cuda.c

+18-15
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ extern THCState *state;
1212
// we assume BHWD format in inputImages
1313
// we assume BHW(YX) format on grids
1414

15-
int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output)
15+
int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output, int * device)
1616
{
1717
// THCState *state = getCutorchState(L);
1818
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
1919
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
2020
// THCudaTensor *output = (THCudaTensor *)luaT_checkudata(L, 4, "torch.CudaTensor");
2121

22+
cudaSetDevice(device[0]);
2223
int success = 0;
2324
success = BilinearSamplerBHWD_updateOutput_cuda_kernel(output->size[2],
2425
output->size[1],
@@ -27,17 +28,17 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso
2728
THCudaTensor_size(state, inputImages, 1),
2829
THCudaTensor_size(state, inputImages, 2),
2930
THCudaTensor_size(state, output, 2),
30-
THCudaTensor_data(state, inputImages),
31+
THCudaTensor_data(state, inputImages),
3132
THCudaTensor_stride(state, inputImages, 0),
3233
THCudaTensor_stride(state, inputImages, 3),
3334
THCudaTensor_stride(state, inputImages, 1),
3435
THCudaTensor_stride(state, inputImages, 2),
35-
THCudaTensor_data(state, grids),
36+
THCudaTensor_data(state, grids),
3637
THCudaTensor_stride(state, grids, 0),
3738
THCudaTensor_stride(state, grids, 3),
3839
THCudaTensor_stride(state, grids, 1),
3940
THCudaTensor_stride(state, grids, 2),
40-
THCudaTensor_data(state, output),
41+
THCudaTensor_data(state, output),
4142
THCudaTensor_stride(state, output, 0),
4243
THCudaTensor_stride(state, output, 3),
4344
THCudaTensor_stride(state, output, 1),
@@ -52,7 +53,7 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso
5253
}
5354

5455
int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
55-
THCudaTensor *gradGrids, THCudaTensor *gradOutput)
56+
THCudaTensor *gradGrids, THCudaTensor *gradOutput, int * device)
5657
{
5758
// THCState *state = getCutorchState(L);
5859
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
@@ -61,6 +62,7 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
6162
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
6263
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");
6364

65+
cudaSetDevice(device[0]);
6466
int success = 0;
6567
success = BilinearSamplerBHWD_updateGradInput_cuda_kernel(gradOutput->size[2],
6668
gradOutput->size[1],
@@ -69,27 +71,27 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
6971
THCudaTensor_size(state, inputImages, 1),
7072
THCudaTensor_size(state, inputImages, 2),
7173
THCudaTensor_size(state, gradOutput, 2),
72-
THCudaTensor_data(state, inputImages),
74+
THCudaTensor_data(state, inputImages),
7375
THCudaTensor_stride(state, inputImages, 0),
7476
THCudaTensor_stride(state, inputImages, 3),
7577
THCudaTensor_stride(state, inputImages, 1),
7678
THCudaTensor_stride(state, inputImages, 2),
77-
THCudaTensor_data(state, grids),
79+
THCudaTensor_data(state, grids),
7880
THCudaTensor_stride(state, grids, 0),
7981
THCudaTensor_stride(state, grids, 3),
8082
THCudaTensor_stride(state, grids, 1),
8183
THCudaTensor_stride(state, grids, 2),
82-
THCudaTensor_data(state, gradInputImages),
84+
THCudaTensor_data(state, gradInputImages),
8385
THCudaTensor_stride(state, gradInputImages, 0),
8486
THCudaTensor_stride(state, gradInputImages, 3),
8587
THCudaTensor_stride(state, gradInputImages, 1),
8688
THCudaTensor_stride(state, gradInputImages, 2),
87-
THCudaTensor_data(state, gradGrids),
89+
THCudaTensor_data(state, gradGrids),
8890
THCudaTensor_stride(state, gradGrids, 0),
8991
THCudaTensor_stride(state, gradGrids, 3),
9092
THCudaTensor_stride(state, gradGrids, 1),
9193
THCudaTensor_stride(state, gradGrids, 2),
92-
THCudaTensor_data(state, gradOutput),
94+
THCudaTensor_data(state, gradOutput),
9395
THCudaTensor_stride(state, gradOutput, 0),
9496
THCudaTensor_stride(state, gradOutput, 3),
9597
THCudaTensor_stride(state, gradOutput, 1),
@@ -104,14 +106,15 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
104106
}
105107

106108
int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids,
107-
THCudaTensor *gradGrids, THCudaTensor *gradOutput)
109+
THCudaTensor *gradGrids, THCudaTensor *gradOutput, int * device)
108110
{
109111
// THCState *state = getCutorchState(L);
110112
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
111113
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
112114
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
113115
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");
114116

117+
cudaSetDevice(device[0]);
115118
int success = 0;
116119
success = BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda_kernel(
117120
gradOutput->size[2],
@@ -121,22 +124,22 @@ int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages,
121124
THCudaTensor_size(state, inputImages, 1),
122125
THCudaTensor_size(state, inputImages, 2),
123126
THCudaTensor_size(state, gradOutput, 2),
124-
THCudaTensor_data(state, inputImages),
127+
THCudaTensor_data(state, inputImages),
125128
THCudaTensor_stride(state, inputImages, 0),
126129
THCudaTensor_stride(state, inputImages, 3),
127130
THCudaTensor_stride(state, inputImages, 1),
128131
THCudaTensor_stride(state, inputImages, 2),
129-
THCudaTensor_data(state, grids),
132+
THCudaTensor_data(state, grids),
130133
THCudaTensor_stride(state, grids, 0),
131134
THCudaTensor_stride(state, grids, 3),
132135
THCudaTensor_stride(state, grids, 1),
133136
THCudaTensor_stride(state, grids, 2),
134-
THCudaTensor_data(state, gradGrids),
137+
THCudaTensor_data(state, gradGrids),
135138
THCudaTensor_stride(state, gradGrids, 0),
136139
THCudaTensor_stride(state, gradGrids, 3),
137140
THCudaTensor_stride(state, gradGrids, 1),
138141
THCudaTensor_stride(state, gradGrids, 2),
139-
THCudaTensor_data(state, gradOutput),
142+
THCudaTensor_data(state, gradOutput),
140143
THCudaTensor_stride(state, gradOutput, 0),
141144
THCudaTensor_stride(state, gradOutput, 3),
142145
THCudaTensor_stride(state, gradOutput, 1),

script/src/my_lib_cuda.h

+3-3
Original file line numberDiff line numberDiff line change
// we assume BHWD format in inputImages
// we assume BHW(YX) format on grids

// Forward bilinear sampling. The trailing `int *` points to the CUDA device
// ordinal; the implementation calls cudaSetDevice(device[0]) before launching
// kernels so multi-GPU callers run on the tensor's device.
int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output, int *);

// Backward pass: fills gradInputImages and gradGrids from gradOutput.
// Trailing `int *` is the device ordinal, as above.
int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
                       THCudaTensor *gradGrids, THCudaTensor *gradOutput, int *);

// Backward pass computing only gradGrids (image gradient skipped).
// Trailing `int *` is the device ordinal, as above.
int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids,
                       THCudaTensor *gradGrids, THCudaTensor *gradOutput, int *);

script/test.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,16 @@
4444
out.backward(input1.data)
4545
print(input1.grad.size(), 'time:', time.time() - start)
4646

47-
input1 = input1.cuda()
48-
input2 = input2.cuda()
49-
50-
start = time.time()
51-
out = s(input1, input2)
52-
print(out.size(), 'time:', time.time() - start)
53-
start = time.time()
54-
out.backward(input1.data)
55-
print('time:', time.time() - start)
47+
with torch.cuda.device(3):
48+
input1 = input1.cuda()
49+
input2 = input2.cuda()
50+
start = time.time()
51+
out = s(input1, input2)
52+
print(out.size(), 'time:', time.time() - start)
53+
start = time.time()
54+
#out.backward(input1.data.cuda())
55+
torch.sum(out).backward()
56+
print('time:', time.time() - start)
5657

5758
input = Variable(torch.from_numpy(np.array([[3.6]], dtype=np.float32)), requires_grad = True)
5859

0 commit comments

Comments
 (0)