@@ -12,13 +12,14 @@ extern THCState *state;
12
12
// we assume BHWD format in inputImages
13
13
// we assume BHW(YX) format on grids
14
14
15
- int BilinearSamplerBHWD_updateOutput_cuda (THCudaTensor * inputImages , THCudaTensor * grids , THCudaTensor * output )
15
+ int BilinearSamplerBHWD_updateOutput_cuda (THCudaTensor * inputImages , THCudaTensor * grids , THCudaTensor * output , int * device )
16
16
{
17
17
// THCState *state = getCutorchState(L);
18
18
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
19
19
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
20
20
// THCudaTensor *output = (THCudaTensor *)luaT_checkudata(L, 4, "torch.CudaTensor");
21
21
22
+ cudaSetDevice (device [0 ]);
22
23
int success = 0 ;
23
24
success = BilinearSamplerBHWD_updateOutput_cuda_kernel (output -> size [2 ],
24
25
output -> size [1 ],
@@ -27,17 +28,17 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso
27
28
THCudaTensor_size (state , inputImages , 1 ),
28
29
THCudaTensor_size (state , inputImages , 2 ),
29
30
THCudaTensor_size (state , output , 2 ),
30
- THCudaTensor_data (state , inputImages ),
31
+ THCudaTensor_data (state , inputImages ),
31
32
THCudaTensor_stride (state , inputImages , 0 ),
32
33
THCudaTensor_stride (state , inputImages , 3 ),
33
34
THCudaTensor_stride (state , inputImages , 1 ),
34
35
THCudaTensor_stride (state , inputImages , 2 ),
35
- THCudaTensor_data (state , grids ),
36
+ THCudaTensor_data (state , grids ),
36
37
THCudaTensor_stride (state , grids , 0 ),
37
38
THCudaTensor_stride (state , grids , 3 ),
38
39
THCudaTensor_stride (state , grids , 1 ),
39
40
THCudaTensor_stride (state , grids , 2 ),
40
- THCudaTensor_data (state , output ),
41
+ THCudaTensor_data (state , output ),
41
42
THCudaTensor_stride (state , output , 0 ),
42
43
THCudaTensor_stride (state , output , 3 ),
43
44
THCudaTensor_stride (state , output , 1 ),
@@ -52,7 +53,7 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso
52
53
}
53
54
54
55
int BilinearSamplerBHWD_updateGradInput_cuda (THCudaTensor * inputImages , THCudaTensor * grids , THCudaTensor * gradInputImages ,
55
- THCudaTensor * gradGrids , THCudaTensor * gradOutput )
56
+ THCudaTensor * gradGrids , THCudaTensor * gradOutput , int * device )
56
57
{
57
58
// THCState *state = getCutorchState(L);
58
59
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
@@ -61,6 +62,7 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
61
62
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
62
63
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");
63
64
65
+ cudaSetDevice (device [0 ]);
64
66
int success = 0 ;
65
67
success = BilinearSamplerBHWD_updateGradInput_cuda_kernel (gradOutput -> size [2 ],
66
68
gradOutput -> size [1 ],
@@ -69,27 +71,27 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
69
71
THCudaTensor_size (state , inputImages , 1 ),
70
72
THCudaTensor_size (state , inputImages , 2 ),
71
73
THCudaTensor_size (state , gradOutput , 2 ),
72
- THCudaTensor_data (state , inputImages ),
74
+ THCudaTensor_data (state , inputImages ),
73
75
THCudaTensor_stride (state , inputImages , 0 ),
74
76
THCudaTensor_stride (state , inputImages , 3 ),
75
77
THCudaTensor_stride (state , inputImages , 1 ),
76
78
THCudaTensor_stride (state , inputImages , 2 ),
77
- THCudaTensor_data (state , grids ),
79
+ THCudaTensor_data (state , grids ),
78
80
THCudaTensor_stride (state , grids , 0 ),
79
81
THCudaTensor_stride (state , grids , 3 ),
80
82
THCudaTensor_stride (state , grids , 1 ),
81
83
THCudaTensor_stride (state , grids , 2 ),
82
- THCudaTensor_data (state , gradInputImages ),
84
+ THCudaTensor_data (state , gradInputImages ),
83
85
THCudaTensor_stride (state , gradInputImages , 0 ),
84
86
THCudaTensor_stride (state , gradInputImages , 3 ),
85
87
THCudaTensor_stride (state , gradInputImages , 1 ),
86
88
THCudaTensor_stride (state , gradInputImages , 2 ),
87
- THCudaTensor_data (state , gradGrids ),
89
+ THCudaTensor_data (state , gradGrids ),
88
90
THCudaTensor_stride (state , gradGrids , 0 ),
89
91
THCudaTensor_stride (state , gradGrids , 3 ),
90
92
THCudaTensor_stride (state , gradGrids , 1 ),
91
93
THCudaTensor_stride (state , gradGrids , 2 ),
92
- THCudaTensor_data (state , gradOutput ),
94
+ THCudaTensor_data (state , gradOutput ),
93
95
THCudaTensor_stride (state , gradOutput , 0 ),
94
96
THCudaTensor_stride (state , gradOutput , 3 ),
95
97
THCudaTensor_stride (state , gradOutput , 1 ),
@@ -104,14 +106,15 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
104
106
}
105
107
106
108
int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda (THCudaTensor * inputImages , THCudaTensor * grids ,
107
- THCudaTensor * gradGrids , THCudaTensor * gradOutput )
109
+ THCudaTensor * gradGrids , THCudaTensor * gradOutput , int * device )
108
110
{
109
111
// THCState *state = getCutorchState(L);
110
112
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
111
113
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
112
114
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
113
115
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");
114
116
117
+ cudaSetDevice (device [0 ]);
115
118
int success = 0 ;
116
119
success = BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda_kernel (
117
120
gradOutput -> size [2 ],
@@ -121,22 +124,22 @@ int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages,
121
124
THCudaTensor_size (state , inputImages , 1 ),
122
125
THCudaTensor_size (state , inputImages , 2 ),
123
126
THCudaTensor_size (state , gradOutput , 2 ),
124
- THCudaTensor_data (state , inputImages ),
127
+ THCudaTensor_data (state , inputImages ),
125
128
THCudaTensor_stride (state , inputImages , 0 ),
126
129
THCudaTensor_stride (state , inputImages , 3 ),
127
130
THCudaTensor_stride (state , inputImages , 1 ),
128
131
THCudaTensor_stride (state , inputImages , 2 ),
129
- THCudaTensor_data (state , grids ),
132
+ THCudaTensor_data (state , grids ),
130
133
THCudaTensor_stride (state , grids , 0 ),
131
134
THCudaTensor_stride (state , grids , 3 ),
132
135
THCudaTensor_stride (state , grids , 1 ),
133
136
THCudaTensor_stride (state , grids , 2 ),
134
- THCudaTensor_data (state , gradGrids ),
137
+ THCudaTensor_data (state , gradGrids ),
135
138
THCudaTensor_stride (state , gradGrids , 0 ),
136
139
THCudaTensor_stride (state , gradGrids , 3 ),
137
140
THCudaTensor_stride (state , gradGrids , 1 ),
138
141
THCudaTensor_stride (state , gradGrids , 2 ),
139
- THCudaTensor_data (state , gradOutput ),
142
+ THCudaTensor_data (state , gradOutput ),
140
143
THCudaTensor_stride (state , gradOutput , 0 ),
141
144
THCudaTensor_stride (state , gradOutput , 3 ),
142
145
THCudaTensor_stride (state , gradOutput , 1 ),
0 commit comments