diff --git a/nestedtensor/csrc/cuda/cuda_kernels.cu b/nestedtensor/csrc/cuda/cuda_kernels.cu
index 97868b67..2428cc0e 100644
--- a/nestedtensor/csrc/cuda/cuda_kernels.cu
+++ b/nestedtensor/csrc/cuda/cuda_kernels.cu
@@ -24,18 +24,18 @@ namespace nteffectivetransformer{
 
-// gelu code from 
+// gelu code from
 // https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v1/fastertransformer/cuda/cuda_kernels.cu#L26-L45
 template <typename T>
 __inline__ __device__
 T gelu(T x)
 {
-  float cdf = 0.5f * 
+  float cdf = 0.5f *
     (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
   return x * cdf;
 }
 
-// reduce code from 
+// reduce code from
 // https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v1/fastertransformer/cuda/cuda_kernels.cu#L47-L73
 
 #define FINAL_MASK 0xffffffff
@@ -53,9 +53,9 @@ template <typename T>
 __inline__ __device__
 T blockReduceSum(T val)
 {
-  static __shared__ T shared[32]; 
-  int lane = threadIdx.x & 0x1f; 
-  int wid = threadIdx.x >> 5;  
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
 
   val = warpReduceSum<T>(val);
@@ -71,7 +71,7 @@ T blockReduceSum(T val)
 
 /// ***************************** add_bias + gelu *****************************
 template <typename T>
-__global__ 
+__global__
 void add_bias_act(T* out, const T* bias, int m, int n)
 {
   T val, reg_bias;
@@ -112,9 +112,9 @@ template void add_bias_act_kernelLauncher<float>(
 
 /// ************************** add_bias + layer_norm **************************
 template <typename T>
-__global__ 
+__global__
 void add_bias_input_layernorm(
-  T* out, const T* input, const T* bias, const T* gamma, 
+  T* out, const T* input, const T* bias, const T* gamma,
   const T* beta, int m, int n)
 {
   int tid = threadIdx.x;
@@ -126,7 +126,7 @@ void add_bias_input_layernorm(
   float local_out = 0.0f;
   for(int i = tid; i < n; i += blockDim.x)
-    local_out += (float)(out[blockIdx.x * n + i] 
+    local_out += (float)(out[blockIdx.x * n + i]
       + input[blockIdx.x * n + i] + __ldg(&bias[i]));
 
   mean = blockReduceSum<float>(local_out);
@@ -141,14 +141,14 @@ void add_bias_input_layernorm(
   __syncthreads();
 
   for(int i = tid; i < n; i += blockDim.x)
-    out[blockIdx.x * n + i] = 
-      (T)(((local_out - s_mean) * rsqrtf(s_variance)) 
+    out[blockIdx.x * n + i] =
+      (T)(((local_out - s_mean) * rsqrtf(s_variance))
       * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i])));
 }
 
 template <typename T>
 void add_bias_input_layernorm_kernelLauncher(
-  T* out, const T* input, const T* bias, 
+  T* out, const T* input, const T* bias,
   const T* gamma, const T* beta, int m, int n, cudaStream_t stream)
 {
   assert(n < 1024);
@@ -159,8 +159,8 @@ void add_bias_input_layernorm_kernelLauncher(
 }
 
 template void add_bias_input_layernorm_kernelLauncher<float>(
-  float* out, const float* input, 
-  const float* bias, const float* gamma, const float* beta, 
+  float* out, const float* input,
+  const float* bias, const float* gamma, const float* beta,
   int m, int n, cudaStream_t stream);
 
 /// *********************************** fin ***********************************
@@ -168,19 +168,19 @@ template void add_bias_input_layernorm_kernelLauncher<float>(
 
 /// *********************** compresse transformer input ***********************
 
-__global__ 
+__global__
 void compress_bert_input(
   // const T* from_tensor,
-  const int* mask, const int* prefix_sum, 
+  const int* mask, const int* prefix_sum,
   // T* to_tensor,
   int* batch_idx, int* word_idx,
-  int batch_size , int seq_len, int hidden_dim) 
+  int batch_size , int seq_len, int hidden_dim)
 {
   int bid = blockIdx.y;  // batch
-  int wid = blockIdx.x;  // word
-  int tid = threadIdx.x; // 
-  
-  /// 1. count pos for from tensor 
+  int wid = blockIdx.x;  // word
+  int tid = threadIdx.x; //
+
+  /// 1. count pos for from tensor
   int mask_idx = bid * seq_len + wid;
   if (mask[mask_idx] > 0.5) {
@@ -191,7 +191,7 @@ void compress_bert_input(
     batch_idx[valid_idx] = bid;
     word_idx[valid_idx] = wid;
   }
-  
+
   // /// 3. copy src data
   // float* src_ptr = (float*)from_tensor;
   // float* dst_ptr = (float*)to_tensor;
@@ -203,10 +203,10 @@ void compress_bert_input(
 
 void compressBertInput_kernelLauncher(
   // const T* from_tensor,
-  const int* mask, const int* prefix_sum, 
+  const int* mask, const int* prefix_sum,
   // T* to_tensor,
   int* batch_idx, int* word_idx,
-  int batch_size , int seq_len, int hidden_dim, cudaStream_t stream) 
+  int batch_size , int seq_len, int hidden_dim, cudaStream_t stream)
 {
   /// TODO : fp32
   dim3 grid(seq_len, batch_size);
@@ -215,7 +215,7 @@ void compressBertInput_kernelLauncher(
   assert(hidden_dim <= 1024);
   compress_bert_input<<<grid, block, 0, stream>>>(
     // from_tensor,
-    mask, prefix_sum, 
+    mask, prefix_sum,
     // to_tensor,
     batch_idx, word_idx,
     batch_size , seq_len, hidden_dim);
@@ -229,11 +229,11 @@ template <typename T>
 __global__
 void restore_bert_output(
   T* to_tensor,
-  const T* from_tensor, const int* batch_idx, const int* word_idx, 
-  int valid_word_num, int seq_len, int hidden_dim) 
+  const T* from_tensor, const int* batch_idx, const int* word_idx,
+  int valid_word_num, int seq_len, int hidden_dim)
 {
   int bid = batch_idx[blockIdx.x];
-  int wid = word_idx[blockIdx.x]; 
+  int wid = word_idx[blockIdx.x];
   int tid = threadIdx.x;
   int vid = blockIdx.x;
@@ -248,24 +248,24 @@ void restore_bert_output(
 
 template <typename T>
 void restoreBertOutput_kernelLauncher(
   T* to_tensor,
-  const T* from_tensor, const int* batch_idx, const int* word_idx, 
-  int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream) 
+  const T* from_tensor, const int* batch_idx, const int* word_idx,
+  int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream)
 {
   // TODO : fp32
   dim3 grid(valid_word_num);
   dim3 block(hidden_dim);
   assert(hidden_dim <= 1024);
   restore_bert_output<<<grid, block, 0, stream>>>(
-    to_tensor, 
+    to_tensor,
     from_tensor, batch_idx, word_idx,
     valid_word_num, seq_len, hidden_dim);
 }
 
 template void restoreBertOutput_kernelLauncher<float>(
   float* to_tensor,
-  const float* from_tensor, const int* batch_idx, const int* word_idx, 
+  const float* from_tensor, const int* batch_idx, const int* word_idx,
   int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream);
-  
+
 /// *********************************** fin ***********************************
 
 /// ***************************** exclusive scan ******************************
@@ -279,14 +279,14 @@ int ELEMENTS_PER_BLOCK = THREADS_PER_BLOCK * 2;
 
 #define LOG_MEM_BANKS 5
 #define CONFLICT_FREE_OFFSET(n) ((n) >> LOG_MEM_BANKS)
 
-__global__ void prescan_large(int *output, const int *input, int n, int *sums) 
+__global__ void prescan_large(int *output, const int *input, int n, int *sums)
 {
   extern __shared__ int temp[];
 
   int blockID = blockIdx.x;
   int threadID = threadIdx.x;
   int blockOffset = blockID * n;
-  
+
   int ai = threadID;
   int bi = threadID + (n / 2);
   int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
@@ -312,11 +312,11 @@ __global__ void prescan_large(int *output, const int *input, int n, int *sums)
 
   __syncthreads();
 
-  if (threadID == 0) { 
+  if (threadID == 0) {
     sums[blockID] = temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)];
     temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)] = 0;
-  } 
-  
+  }
+
   for (int d = 1; d < n; d *= 2) // traverse down tree & build scan
   {
     offset >>= 1;
@@ -350,7 +350,7 @@ __global__ void prescan_arbitrary(
   int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
   int bankOffsetB = CONFLICT_FREE_OFFSET(bi);
-  
+
   if (threadID < n) {
     temp[ai + bankOffsetA] = input[ai];
     temp[bi + bankOffsetB] = input[bi];
@@ -359,11 +359,11 @@ __global__ void prescan_arbitrary(
     temp[ai + bankOffsetA] = 0;
     temp[bi + bankOffsetB] = 0;
   }
-  
+
   int offset = 1;
   // build sum in place up the tree
-  for (int d = powerOfTwo >> 1; d > 0; d >>= 1) 
+  for (int d = powerOfTwo >> 1; d > 0; d >>= 1)
   {
     __syncthreads();
     if (threadID < d)
@@ -380,7 +380,7 @@ __global__ void prescan_arbitrary(
 
   if (threadID == 0) {
     // clear the last element
-    temp[powerOfTwo - 1 + CONFLICT_FREE_OFFSET(powerOfTwo - 1)] = 0; 
+    temp[powerOfTwo - 1 + CONFLICT_FREE_OFFSET(powerOfTwo - 1)] = 0;
   }
 
   for (int d = 1; d < powerOfTwo; d *= 2) // traverse down tree & build scan
@@ -435,15 +435,15 @@ int nextPowerOfTwo(int x) {
 void scanSmallDeviceArray(
   int *d_out, const int* d_in, const int length, const cudaStream_t stream);
 void scanLargeDeviceArray(
-  int *d_out, const int* d_in, const int length, int *d_buf, 
+  int *d_out, const int* d_in, const int length, int *d_buf,
   const cudaStream_t stream);
 void scanLargeEvenDeviceArray(
-  int *d_out, const int* d_in, const int length, int *d_buf, 
+  int *d_out, const int* d_in, const int length, int *d_buf,
   const cudaStream_t stream);
 
 void scanLargeEvenDeviceArray(
-  int *d_out, const int* d_in, const int length, int *d_buf, 
-  const cudaStream_t stream) 
+  int *d_out, const int* d_in, const int length, int *d_buf,
+  const cudaStream_t stream)
 {
   const int blocks = length / ELEMENTS_PER_BLOCK;
   const int sharedMemArraySize = ELEMENTS_PER_BLOCK * sizeof(int);
@@ -471,7 +471,7 @@ void scanLargeEvenDeviceArray(
 }
 
 void scanSmallDeviceArray(
-  int *d_out, const int* d_in, const int length, const cudaStream_t stream) 
+  int *d_out, const int* d_in, const int length, const cudaStream_t stream)
 {
   int powerOfTwo = nextPowerOfTwo(length);
   prescan_arbitrary
@@ -479,10 +479,10 @@ void scanSmallDeviceArray(
     <<<1, (length + 1) / 2, 2 * powerOfTwo * sizeof(int), stream>>>(
       d_out, d_in, length, powerOfTwo);
 }
 
-/// 
+///
 void scanLargeDeviceArray(
-  int *d_out, const int* d_in, const int length, int *d_buf, 
-  const cudaStream_t stream) 
+  int *d_out, const int* d_in, const int length, int *d_buf,
+  const cudaStream_t stream)
 {
   int remainder = length % (ELEMENTS_PER_BLOCK);
   if (remainder == 0) {
@@ -493,20 +493,20 @@ void scanLargeDeviceArray(
     int lengthMultiple = length - remainder;
    scanLargeEvenDeviceArray(d_out, d_in, lengthMultiple, d_buf, stream);
 
-    // scan the remaining elements and add the (inclusive) 
+    // scan the remaining elements and add the (inclusive)
     // last element of the large scan to this
     int *startOfOutputArray = &(d_out[lengthMultiple]);
     scanSmallDeviceArray(
       startOfOutputArray, &(d_in[lengthMultiple]), remainder, stream);
 
     add<<<1, remainder, 0, stream>>>(
-      startOfOutputArray, remainder, &(d_in[lengthMultiple - 1]), 
+      startOfOutputArray, remainder, &(d_in[lengthMultiple - 1]),
       &(d_out[lengthMultiple - 1]));
   }
 }
 
 void exclusiveScan_kernelLauncher(
-  int* d_out, const int* d_in, const int length, const cudaStream_t stream) 
+  int* d_out, const int* d_in, const int length, const cudaStream_t stream)
 {
   if (length > ELEMENTS_PER_BLOCK) {
     scanLargeDeviceArray(d_out, d_in, length, d_out + length, stream);
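One note for reviewers of the gelu hunk: the constant 0.7978845608028654f is sqrt(2/pi), so the kernel computes the usual tanh approximation gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A quick host-side check (hypothetical, not part of this patch):

// Hypothetical host-side sanity check, not part of this patch: evaluates the
// same tanh approximation as the gelu() device function above.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 1.0;
  const double cdf =
      0.5 * (1.0 + std::tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)));
  std::printf("sqrt(2/pi) = %.16f\n",
              std::sqrt(2.0 / 3.141592653589793)); // 0.7978845608028654
  std::printf("gelu(1.0)  = %f\n", x * cdf);       // ~0.8412 (exact GELU: ~0.8413)
  return 0;
}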
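For reviewers wondering how the launchers touched here compose: the EffectiveTransformer-style flow is (1) exclusive-scan the padding mask, (2) record batch/word coordinates of the non-padding tokens, (3) run the model on the packed rows, (4) scatter the results back into the padded layout. Below is a minimal host-side sketch under those assumptions; the driver name run_compress_restore and all buffer names are invented for illustration, and only the three launcher signatures come from this file.

// Hypothetical driver composing the launchers defined in cuda_kernels.cu.
#include <cuda_runtime.h>

namespace nteffectivetransformer {
// Prototypes as defined in this file.
void exclusiveScan_kernelLauncher(
    int* d_out, const int* d_in, const int length, const cudaStream_t stream);
void compressBertInput_kernelLauncher(
    const int* mask, const int* prefix_sum,
    int* batch_idx, int* word_idx,
    int batch_size, int seq_len, int hidden_dim, cudaStream_t stream);
template <typename T>
void restoreBertOutput_kernelLauncher(
    T* to_tensor, const T* from_tensor,
    const int* batch_idx, const int* word_idx,
    int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream);
} // namespace nteffectivetransformer

void run_compress_restore(
    const int* d_mask,       // [batch_size, seq_len], 1 = token, 0 = padding
    const float* d_compact,  // [valid_word_num, hidden_dim] packed output
    float* d_padded,         // [batch_size, seq_len, hidden_dim] scatter target
    int batch_size, int seq_len, int hidden_dim, cudaStream_t stream)
{
  namespace nt = nteffectivetransformer;
  const int n = batch_size * seq_len;

  // exclusiveScan_kernelLauncher uses d_out + length as scratch in the
  // large-array path, so allocate 2 * n ints for the scan output, not n.
  int *d_prefix_sum, *d_batch_idx, *d_word_idx;
  cudaMalloc(&d_prefix_sum, 2 * n * sizeof(int));
  cudaMalloc(&d_batch_idx, n * sizeof(int));
  cudaMalloc(&d_word_idx, n * sizeof(int));

  // 1. exclusive prefix sum of the mask: packed row index per kept token.
  nt::exclusiveScan_kernelLauncher(d_prefix_sum, d_mask, n, stream);

  // 2. record the (batch, word) coordinates of every non-padding token.
  nt::compressBertInput_kernelLauncher(
      d_mask, d_prefix_sum, d_batch_idx, d_word_idx,
      batch_size, seq_len, hidden_dim, stream);

  // valid_word_num = last exclusive-scan value + last mask value.
  int last_sum = 0, last_mask = 0;
  cudaMemcpyAsync(&last_sum, d_prefix_sum + n - 1, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaMemcpyAsync(&last_mask, d_mask + n - 1, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  const int valid_word_num = last_sum + last_mask;

  // 3. ... the transformer runs on the packed [valid_word_num, hidden_dim]
  //    rows elsewhere (the data copy inside compress_bert_input is
  //    commented out in this file) ...

  // 4. scatter each packed row back to its original (batch, word) slot.
  nt::restoreBertOutput_kernelLauncher<float>(
      d_padded, d_compact, d_batch_idx, d_word_idx,
      valid_word_num, seq_len, hidden_dim, stream);

  cudaFree(d_prefix_sum); cudaFree(d_batch_idx); cudaFree(d_word_idx);
}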
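The scan section only receives whitespace fixes here, but it is the piece that turns the mask into packed row indices, so a CPU reference is handy when validating the launcher. A sketch of such a helper (hypothetical test code, not in this patch):

// Hypothetical CPU reference for the device exclusive scan, for testing only.
#include <vector>

std::vector<int> exclusiveScanRef(const std::vector<int>& in) {
  std::vector<int> out(in.size(), 0);
  for (size_t i = 1; i < in.size(); ++i)
    out[i] = out[i - 1] + in[i - 1]; // out[0] = 0; each entry excludes in[i]
  return out;
}

// For a 0/1 padding mask, out[i] is exactly the packed row index that
// compress_bert_input assigns to token i when mask[i] == 1.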