namespace nteffectivetransformer {

// gelu code from
// https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v1/fastertransformer/cuda/cuda_kernels.cu#L26-L45
template <typename T>
__inline__ __device__
T gelu(T x)
{
  float cdf = 0.5f *
    (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}
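// The constant 0.7978845608028654f is sqrt(2/pi), so the kernel above computes the
// usual tanh approximation GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// A minimal host-side sketch of the same formula, handy for checking kernel output
// on the CPU (the helper name gelu_ref is illustrative, not part of this file):
//
//   #include <cmath>
//   static inline float gelu_ref(float x) {
//     float cdf = 0.5f * (1.0f + std::tanh(0.7978845608028654f
//                                          * (x + 0.044715f * x * x * x)));
//     return x * cdf;
//   }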

// reduce code from
// https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v1/fastertransformer/cuda/cuda_kernels.cu#L47-L73

#define FINAL_MASK 0xffffffff
@@ -53,9 +53,9 @@ template <typename T>
__inline__ __device__
T blockReduceSum(T val)
{
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;  // lane index within the warp (0-31)
  int wid  = threadIdx.x >> 5;    // warp index within the block

  val = warpReduceSum<T>(val);
@@ -71,7 +71,7 @@ T blockReduceSum(T val)
/// ***************************** add_bias + gelu *****************************

template <typename T>
__global__
void add_bias_act(T* out, const T* bias, int m, int n)
{
  T val, reg_bias;
@@ -112,9 +112,9 @@ template void add_bias_act_kernelLauncher<float>(
/// ************************** add_bias + layer_norm **************************

template <typename T>
__global__
void add_bias_input_layernorm(
  T* out, const T* input, const T* bias, const T* gamma,
  const T* beta, int m, int n)
{
  int tid = threadIdx.x;
@@ -126,7 +126,7 @@ void add_bias_input_layernorm(

  float local_out = 0.0f;
  for (int i = tid; i < n; i += blockDim.x)
    local_out += (float)(out[blockIdx.x * n + i]
      + input[blockIdx.x * n + i] + __ldg(&bias[i]));

  mean = blockReduceSum<float>(local_out);
@@ -141,14 +141,14 @@ void add_bias_input_layernorm(
  __syncthreads();

  for (int i = tid; i < n; i += blockDim.x)
    out[blockIdx.x * n + i] =
      (T)(((local_out - s_mean) * rsqrtf(s_variance))
      * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i])));
}

template <typename T>
void add_bias_input_layernorm_kernelLauncher(
  T* out, const T* input, const T* bias,
  const T* gamma, const T* beta, int m, int n, cudaStream_t stream)
{
  assert(n < 1024);
@@ -159,28 +159,28 @@ void add_bias_input_layernorm_kernelLauncher(
}

template void add_bias_input_layernorm_kernelLauncher<float>(
  float* out, const float* input,
  const float* bias, const float* gamma, const float* beta,
  int m, int n, cudaStream_t stream);

/// *********************************** fin ***********************************


/// *********************** compress transformer input ************************

__global__
void compress_bert_input(
  // const T* from_tensor,
  const int* mask, const int* prefix_sum,
  // T* to_tensor,
  int* batch_idx, int* word_idx,
  int batch_size, int seq_len, int hidden_dim)
{
  int bid = blockIdx.y;   // batch
  int wid = blockIdx.x;   // word
  int tid = threadIdx.x;  // thread within the block

  /// 1. count pos for from tensor
  int mask_idx = bid * seq_len + wid;

  if (mask[mask_idx] > 0.5) {
@@ -191,7 +191,7 @@ void compress_bert_input(
    batch_idx[valid_idx] = bid;
    word_idx[valid_idx] = wid;
  }

  // /// 3. copy src data
  // float* src_ptr = (float*)from_tensor;
  // float* dst_ptr = (float*)to_tensor;
@@ -203,10 +203,10 @@ void compress_bert_input(
void compressBertInput_kernelLauncher(
  // const T* from_tensor,
  const int* mask, const int* prefix_sum,
  // T* to_tensor,
  int* batch_idx, int* word_idx,
  int batch_size, int seq_len, int hidden_dim, cudaStream_t stream)
{
  /// TODO : fp32
  dim3 grid(seq_len, batch_size);
@@ -215,7 +215,7 @@ void compressBertInput_kernelLauncher(
  assert(hidden_dim <= 1024);
  compress_bert_input<<<grid, block, 0, stream>>>(
    // from_tensor,
    mask, prefix_sum,
    // to_tensor,
    batch_idx, word_idx,
    batch_size, seq_len, hidden_dim);
@@ -229,11 +229,11 @@ template<typename T>
__global__
void restore_bert_output(
  T* to_tensor,
  const T* from_tensor, const int* batch_idx, const int* word_idx,
  int valid_word_num, int seq_len, int hidden_dim)
{
  int bid = batch_idx[blockIdx.x];
  int wid = word_idx[blockIdx.x];
  int tid = threadIdx.x;
  int vid = blockIdx.x;
@@ -248,24 +248,24 @@ void restore_bert_output(
template <typename T>
void restoreBertOutput_kernelLauncher(
  T* to_tensor,
  const T* from_tensor, const int* batch_idx, const int* word_idx,
  int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream)
{
  // TODO : fp32
  dim3 grid(valid_word_num);
  dim3 block(hidden_dim);
  assert(hidden_dim <= 1024);
  restore_bert_output<<<grid, block, 0, stream>>>(
    to_tensor,
    from_tensor, batch_idx, word_idx,
    valid_word_num, seq_len, hidden_dim);
}

template void restoreBertOutput_kernelLauncher<float>(
  float* to_tensor,
  const float* from_tensor, const int* batch_idx, const int* word_idx,
  int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream);
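// A rough host-side sketch of how these launchers fit together (illustrative only:
// buffer names such as d_mask/d_prefix_sum and the way valid_word_num is obtained
// are assumptions, not defined in this file):
//
//   int len = batch_size * seq_len;
//   // 1. exclusive prefix sum of the 0/1 attention mask
//   exclusiveScan_kernelLauncher(d_prefix_sum, d_mask, len, stream);
//   // 2. record the (batch, word) coordinates of every valid token
//   compressBertInput_kernelLauncher(d_mask, d_prefix_sum,
//                                    d_batch_idx, d_word_idx,
//                                    batch_size, seq_len, hidden_dim, stream);
//   // ... run the transformer layers on the compacted valid_word_num rows ...
//   // 3. scatter the compacted result back to [batch, seq_len, hidden] layout
//   restoreBertOutput_kernelLauncher<float>(d_to_tensor, d_from_tensor,
//                                           d_batch_idx, d_word_idx,
//                                           valid_word_num, seq_len,
//                                           hidden_dim, stream);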

/// *********************************** fin ***********************************

/// ***************************** exclusive scan ******************************
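// The code below appears to be the classic work-efficient (Blelloch) exclusive
// scan, with CONFLICT_FREE_OFFSET padding shared-memory indices to avoid bank
// conflicts. In this file it presumably produces the prefix_sum array over the
// attention mask that compress_bert_input consumes.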
@@ -279,14 +279,14 @@ int ELEMENTS_PER_BLOCK = THREADS_PER_BLOCK * 2;
#define LOG_MEM_BANKS 5
#define CONFLICT_FREE_OFFSET(n) ((n) >> LOG_MEM_BANKS)

__global__ void prescan_large(int *output, const int *input, int n, int *sums)
{
  extern __shared__ int temp[];

  int blockID = blockIdx.x;
  int threadID = threadIdx.x;
  int blockOffset = blockID * n;

  int ai = threadID;
  int bi = threadID + (n / 2);
  int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
@@ -312,11 +312,11 @@ __global__ void prescan_large(int *output, const int *input, int n, int *sums)
  __syncthreads();

  if (threadID == 0) {
    sums[blockID] = temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)];
    temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)] = 0;
  }

  for (int d = 1; d < n; d *= 2) // traverse down tree & build scan
  {
    offset >>= 1;
@@ -350,7 +350,7 @@ __global__ void prescan_arbitrary(
  int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
  int bankOffsetB = CONFLICT_FREE_OFFSET(bi);

  if (threadID < n) {
    temp[ai + bankOffsetA] = input[ai];
    temp[bi + bankOffsetB] = input[bi];
@@ -359,11 +359,11 @@ __global__ void prescan_arbitrary(
    temp[ai + bankOffsetA] = 0;
    temp[bi + bankOffsetB] = 0;
  }

  int offset = 1;
  // build sum in place up the tree
  for (int d = powerOfTwo >> 1; d > 0; d >>= 1)
  {
    __syncthreads();
    if (threadID < d)
@@ -380,7 +380,7 @@ __global__ void prescan_arbitrary(

  if (threadID == 0) {
    // clear the last element
    temp[powerOfTwo - 1 + CONFLICT_FREE_OFFSET(powerOfTwo - 1)] = 0;
  }

  for (int d = 1; d < powerOfTwo; d *= 2) // traverse down tree & build scan
@@ -435,15 +435,15 @@ int nextPowerOfTwo(int x) {
void scanSmallDeviceArray(
  int *d_out, const int* d_in, const int length, const cudaStream_t stream);
void scanLargeDeviceArray(
  int *d_out, const int* d_in, const int length, int *d_buf,
  const cudaStream_t stream);
void scanLargeEvenDeviceArray(
  int *d_out, const int* d_in, const int length, int *d_buf,
  const cudaStream_t stream);

void scanLargeEvenDeviceArray(
  int *d_out, const int* d_in, const int length, int *d_buf,
  const cudaStream_t stream)
{
  const int blocks = length / ELEMENTS_PER_BLOCK;
  const int sharedMemArraySize = ELEMENTS_PER_BLOCK * sizeof(int);
@@ -471,18 +471,18 @@ void scanLargeEvenDeviceArray(
}

void scanSmallDeviceArray(
  int *d_out, const int* d_in, const int length, const cudaStream_t stream)
{
  int powerOfTwo = nextPowerOfTwo(length);
  prescan_arbitrary
    <<<1, (length + 1) / 2, 2 * powerOfTwo * sizeof(int), stream >>>(
    d_out, d_in, length, powerOfTwo);
}

///
void scanLargeDeviceArray(
  int *d_out, const int* d_in, const int length, int *d_buf,
  const cudaStream_t stream)
{
  int remainder = length % (ELEMENTS_PER_BLOCK);
  if (remainder == 0) {
@@ -493,20 +493,20 @@ void scanLargeDeviceArray(
    int lengthMultiple = length - remainder;
    scanLargeEvenDeviceArray(d_out, d_in, lengthMultiple, d_buf, stream);

    // scan the remaining elements and add the (inclusive)
    // last element of the large scan to this
    int *startOfOutputArray = &(d_out[lengthMultiple]);
    scanSmallDeviceArray(
      startOfOutputArray, &(d_in[lengthMultiple]), remainder, stream);

    add<<<1, remainder, 0, stream>>>(
      startOfOutputArray, remainder, &(d_in[lengthMultiple - 1]),
      &(d_out[lengthMultiple - 1]));
  }
}

void exclusiveScan_kernelLauncher(
  int* d_out, const int* d_in, const int length, const cudaStream_t stream)
{
  if (length > ELEMENTS_PER_BLOCK) {
    scanLargeDeviceArray(d_out, d_in, length, d_out + length, stream);