This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 3e535a7

More efficient MHA - faster padding (#407)
1 parent e2aaac9 commit 3e535a7

8 files changed: +157 -107 lines changed

nestedtensor/csrc/BinaryOps.cpp

+1 -3

@@ -31,9 +31,7 @@ Tensor NestedTensor_add_Tensor(
     }
   }
   if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
-    if (!get_is_contiguous(self)) {
-      self = NestedTensor_contiguous(self);
-    }
+    self = NestedTensor_contiguous(self);
     int64_t self_dim = get_dim(self);
     auto self_opt_sizes = get_opt_sizes(self);
     if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&

nestedtensor/csrc/cuda/mha.cpp

+25 -96

@@ -36,6 +36,9 @@ at::Tensor bt_min_mha(
   TORCH_CHECK(get_dim(query) == 3, "query needs to be 3 dim.");
   TORCH_CHECK(get_dim(key) == 3, "key needs to be 3 dim.");
   TORCH_CHECK(get_dim(value) == 3, "value needs to be 3 dim.");
+  TORCH_CHECK(get_nested_dim(query) == 1, "Query nested dim isn't 1.");
+  TORCH_CHECK(get_nested_dim(key) == 1, "Key nested dim isn't 1.");
+  TORCH_CHECK(get_nested_dim(value) == 1, "Value nested dim isn't 1.");
   // TORCH_CHECK(in_proj_bias, "Input projection bias needs to be defined.");
   // auto opt_sizes = get_opt_sizes(query);
   // if (!opt_sizes[2]) {
@@ -57,88 +60,31 @@ at::Tensor bt_min_mha(
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
   at::cuda::setCurrentCUDAStream(defaultStream);

-  int64_t input_tensor_size = batch_size * head_num * seq_len * size_per_head;
-  int64_t attn_tensor_size = batch_size * head_num * seq_len * seq_len;
-  int word_num = batch_size * seq_len;
-  Tensor prefix_sum = torch::zeros({word_num}, options);
-  Tensor batch_idx = torch::zeros({word_num}, options);
-  Tensor word_idx = torch::zeros({word_num}, options);
+  at::Tensor packed = at::matmul(query, attr_kernel.t()) + attr_bias;

-  int* prefix_sum_ptr = prefix_sum.data_ptr<int>();
-  int* batch_idx_ptr = batch_idx.data_ptr<int>();
-  int* word_idx_ptr = word_idx.data_ptr<int>();
-
-  at::Tensor tmp = get_buffer(query);
-
-  auto query_esize = get_efficient_nested_size(query);
-  TORCH_CHECK(query_esize.height() == 1, "Query nested dim isn't 1.");
-  auto query_esize_sizes = query_esize.sizes();
-
-  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
-  attr_mask = attr_mask * attr_mask.transpose(2, 3);
-
-  nteffectivetransformer::exclusiveScan_kernelLauncher(
-      prefix_sum_ptr,
-      input_mask.data_ptr<int>(),
-      input_mask.size(0) * input_mask.size(1),
-      defaultStream);
-
-
-  nteffectivetransformer::compressBertInput_kernelLauncher(
-      input_mask.data_ptr<int>(),
-      prefix_sum_ptr,
-      batch_idx_ptr,
-      word_idx_ptr,
-      (int32_t)(batch_size),
-      (int32_t)(seq_len),
-      (int32_t)(embedding_dim),
-      defaultStream);
-
-  at::Tensor packed = at::matmul(query, attr_kernel.t());
+  // TODO: Move into implementation of chunk for NestedTensor
   at::Tensor packed_buf = get_buffer(packed).contiguous().reshape({-1, 3 * embedding_dim});
   std::vector<at::Tensor> packed_chunks = packed_buf.chunk(3, -1);
-  at::Tensor q_buf = packed_chunks[0].contiguous().reshape({-1});
-  at::Tensor k_buf = packed_chunks[1].contiguous().reshape({-1});
-  at::Tensor v_buf = packed_chunks[2].contiguous().reshape({-1});
-
-  int valid_word_num = get_numel(query) / embedding_dim;
-
-  at::Tensor query_buf = torch::zeros(
-      {batch_size, head_num, seq_len, size_per_head}, float_options);
-  at::Tensor key_buf = torch::zeros(
-      {batch_size, head_num, seq_len, size_per_head}, float_options);
-  at::Tensor val_buf = torch::zeros(
-      {batch_size, head_num, seq_len, size_per_head}, float_options);
-  at::Tensor attr_out =
-      torch::zeros({valid_word_num, embedding_dim}, float_options);
-
-  std::vector<at::Tensor> bias_chunks = attr_bias.chunk(3);
-  at::Tensor attr_bias_Q = bias_chunks[0];
-  at::Tensor attr_bias_K = bias_chunks[1];
-  at::Tensor attr_bias_V = bias_chunks[2];
-
-  nteffectivetransformer::cuda::add_QKV_bias_padding_kernelLauncher<float>(
-      q_buf.data_ptr<float>(),
-      attr_bias_Q.data_ptr<float>(),
-      k_buf.data_ptr<float>(),
-      attr_bias_K.data_ptr<float>(),
-      v_buf.data_ptr<float>(),
-      attr_bias_V.data_ptr<float>(),
-      query_buf.data_ptr<float>(),
-      key_buf.data_ptr<float>(),
-      val_buf.data_ptr<float>(),
-      valid_word_num,
-      batch_size,
-      seq_len,
-      head_num,
-      size_per_head,
-      batch_idx_ptr,
-      word_idx_ptr,
-      defaultStream);
+  at::Tensor q_buf_ = packed_chunks[0].contiguous().reshape({-1});
+  at::Tensor k_buf_ = packed_chunks[1].contiguous().reshape({-1});
+  at::Tensor v_buf_ = packed_chunks[2].contiguous().reshape({-1});
+  at::Tensor q = wrap_buffer(std::move(q_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
+  at::Tensor k = wrap_buffer(std::move(k_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
+  at::Tensor v = wrap_buffer(std::move(v_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
+
+  at::Tensor query_buf = to_padded_tensor(q, 0).contiguous();
+  at::Tensor key_buf = to_padded_tensor(k, 0).contiguous();
+  at::Tensor val_buf = to_padded_tensor(v, 0).contiguous();
+  query_buf = query_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
+  key_buf = key_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
+  val_buf = val_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);

   key_buf = key_buf.transpose(2, 3);
   at::Tensor attn_output_weights = at::matmul(query_buf, key_buf).contiguous();

+  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
+  attr_mask = attr_mask * attr_mask.transpose(2, 3);
+
   nteffectivetransformer::cuda::softmax_kernel_kernelLauncher<float>(
       attn_output_weights.data_ptr<float>(),
       attr_mask.data_ptr<float>(),
@@ -148,27 +94,10 @@ at::Tensor bt_min_mha(
       (float)(scaling),
       defaultStream);

-  auto attn_output = at::matmul(attn_output_weights, val_buf);
-
-  nteffectivetransformer::cuda::transpose_rm_padding_kernelLauncher<float>(
-      attn_output.data_ptr<float>(),
-      attr_out.data_ptr<float>(),
-      valid_word_num,
-      batch_size,
-      seq_len,
-      head_num,
-      size_per_head,
-      batch_idx_ptr,
-      word_idx_ptr,
-      defaultStream);
-
-  // TODO: Bias is variably sized, need to add support for that.
-  at::Tensor result = at::matmul(attr_out, out_proj_weight.t());
-  result = result.reshape({-1});
-  return wrap_buffer(
-      std::move(result),
-      get_efficient_nested_size(query),
-      get_efficient_nested_stride(query));
+  auto attn_output = at::matmul(attn_output_weights, val_buf).contiguous();
+  attn_output = attn_output.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim}).contiguous();
+  at::Tensor attr_out = from_padded_tensor(attn_output, get_efficient_nested_size(query), get_efficient_nested_stride(query));
+  return at::matmul(attr_out, out_proj_weight.t());
 }

 TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
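
Note: the rewritten bt_min_mha above now pads the nested q/k/v into dense [batch_size, head_num, seq_len, size_per_head] buffers, runs ordinary batched matmuls plus the masked-softmax kernel, and strips the padding again with from_padded_tensor. The standalone sketch below (assuming libtorch is available; padded_mha_sketch and its parameter names are illustrative, not part of this repository) reproduces the same dense math on already-padded inputs, with the attention mask built from input_mask as in the hunk. It assumes every sequence has at least one real token, so no softmax row is fully masked.

#include <torch/torch.h>
#include <limits>

// Illustrative sketch of the padded attention path used by bt_min_mha above.
// q, k, v: [batch, heads, seq, head_dim], zero-padded along seq.
// input_mask: [batch, seq] with 1 for real tokens and 0 for padding.
at::Tensor padded_mha_sketch(
    at::Tensor q, at::Tensor k, at::Tensor v,
    at::Tensor input_mask, double scaling) {
  int64_t seq_len = q.size(2);
  // Outer product of the 0/1 mask -> [batch, 1, seq, seq], as in the hunk.
  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(q.options());
  attr_mask = attr_mask * attr_mask.transpose(2, 3);
  at::Tensor scores = at::matmul(q, k.transpose(2, 3)) * scaling;
  // The commit fuses masking into softmax_kernel_kernelLauncher; here we
  // approximate it by pushing masked positions to -inf before at::softmax.
  scores = scores.masked_fill(
      attr_mask == 0, -std::numeric_limits<float>::infinity());
  at::Tensor weights = at::softmax(scores, -1);
  return at::matmul(weights, v);  // [batch, heads, seq, head_dim]
}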

nestedtensor/csrc/cuda/padding.cu

+70 -3

@@ -18,8 +18,18 @@ void add_padding(
     const int inner_size)
 {
   const int batch_id = blockIdx.x;
-  for (int i = 0; i < (offsets[batch_id + 1] - offsets[batch_id]) * inner_size; i++) {
-    output[batch_id * output_stride + i] = input[offsets[batch_id] * inner_size + i];
+  const int grain_size = blockDim.x;
+  const int tid = threadIdx.x;
+  const int range = (offsets[batch_id + 1] - offsets[batch_id]) * inner_size;
+  const int num_chunks = range / grain_size;
+  for (int id = 0; id < num_chunks; id++) {
+    output[batch_id * output_stride + id * grain_size + tid]
+      = input[offsets[batch_id] * inner_size + id * grain_size + tid];
+  }
+  const int leftover = num_chunks * grain_size;
+  if (leftover + tid < range) {
+    output[batch_id * output_stride + leftover + tid]
+      = input[offsets[batch_id] * inner_size + leftover + tid];
   }
 }

@@ -36,7 +46,7 @@ void add_padding_kernelLauncher(
   dim3 grid;
   grid.x = batch_size;

-  add_padding<float><<<grid, 1, 0, stream>>>(
+  add_padding<float><<<grid, 1024, 0, stream>>>(
       input,
       output,
       offsets,
@@ -111,5 +121,62 @@ template void add_padding_mask_kernelLauncher<float>(
     const int output_stride,
     const int inner_size,
     const cudaStream_t stream);
+
+template<typename T>
+__global__
+void remove_padding(
+    const T* input,
+    T* output,
+    const int* offsets,
+    const int batch_size,
+    const int output_stride,
+    const int inner_size)
+{
+  const int batch_id = blockIdx.x;
+  const int grain_size = blockDim.x;
+  const int tid = threadIdx.x;
+  const int range = (offsets[batch_id + 1] - offsets[batch_id]) * inner_size;
+  const int num_chunks = range / grain_size;
+  for (int id = 0; id < num_chunks; id++) {
+    output[offsets[batch_id] * inner_size + id * grain_size + tid]
+      = input[batch_id * output_stride + id * grain_size + tid];
+  }
+  const int leftover = num_chunks * grain_size;
+  if (leftover + tid < range) {
+    output[offsets[batch_id] * inner_size + leftover + tid]
+      = input[batch_id * output_stride + leftover + tid];
+  }
+}
+
+template<typename T>
+void remove_padding_kernelLauncher(
+    T* input, // [batch_size x None]
+    T* output, // [batch_size x max(input.nested_size(1)) x inner_size]
+    const int* offsets, // [batch_size]
+    const int batch_size,
+    const int output_stride,
+    const int inner_size,
+    const cudaStream_t stream)
+{
+  dim3 grid;
+  grid.x = batch_size;
+
+  remove_padding<float><<<grid, 1024, 0, stream>>>(
+      input,
+      output,
+      offsets,
+      batch_size,
+      output_stride,
+      inner_size);
+}
+
+template void remove_padding_kernelLauncher<float>(
+    float* input,
+    float* output,
+    const int* offsets,
+    const int batch_size,
+    const int output_stride,
+    const int inner_size,
+    const cudaStream_t stream);
 }
 }
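
Note: the change to add_padding (and the new remove_padding) swaps a single-threaded per-sequence copy for a block of 1024 threads that walks each sequence in blockDim.x-sized chunks, with a tail branch for the remainder. The toy program below (assuming nvcc and a CUDA device; pad_copy and the host driver are illustrative, not part of this repository) exercises the same access pattern with inner_size fixed to 1.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// One block per sequence; blockDim.x threads cooperatively copy that sequence
// into its padded row, chunk by chunk, then handle the leftover elements.
__global__ void pad_copy(const float* input, float* output,
                         const int* offsets, int output_stride) {
  const int batch_id = blockIdx.x;
  const int grain_size = blockDim.x;
  const int tid = threadIdx.x;
  const int range = offsets[batch_id + 1] - offsets[batch_id];
  const int num_chunks = range / grain_size;
  for (int id = 0; id < num_chunks; id++) {
    output[batch_id * output_stride + id * grain_size + tid] =
        input[offsets[batch_id] + id * grain_size + tid];
  }
  const int leftover = num_chunks * grain_size;
  if (leftover + tid < range) {
    output[batch_id * output_stride + leftover + tid] =
        input[offsets[batch_id] + leftover + tid];
  }
}

int main() {
  // Two sequences of lengths 3 and 5, padded out to a row stride of 5.
  std::vector<float> h_in = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int> h_off = {0, 3, 8};
  std::vector<float> h_out(2 * 5, 0.f);
  float *d_in, *d_out;
  int* d_off;
  cudaMalloc((void**)&d_in, h_in.size() * sizeof(float));
  cudaMalloc((void**)&d_out, h_out.size() * sizeof(float));
  cudaMalloc((void**)&d_off, h_off.size() * sizeof(int));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_off, h_off.data(), h_off.size() * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, h_out.size() * sizeof(float));
  pad_copy<<<2, 256>>>(d_in, d_out, d_off, 5);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g ", v);  // expected: 1 2 3 0 0 4 5 6 7 8
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_off);
  return 0;
}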

nestedtensor/csrc/cuda/padding.h

+10

@@ -28,5 +28,15 @@ void add_padding_mask_kernelLauncher(
     const int inner_size,
     const cudaStream_t stream);

+template <typename T>
+void remove_padding_kernelLauncher(
+    T* input,
+    T* output,
+    const int* lengths,
+    const int batch_size,
+    const int output_stride,
+    const int inner_size,
+    const cudaStream_t stream);
+
 }
 } // namespace nested_tensor

nestedtensor/csrc/masking.cpp

+31

@@ -391,6 +391,37 @@ Tensor to_mask(
   return merge_mask(res_mask, mask_dim);
 }

+Tensor from_padded_tensor(Tensor padded, EfficientSizeNode target_size,
+                          EfficientSizeNode target_stride) {
+#ifdef WITH_CUDA
+  if (padded.dim() == 3 && target_size.dim() == 3 && get_is_contiguous(padded)) {
+    auto nt_opt_size = target_size.opt_sizes();
+    if (nt_opt_size[2] && padded.is_cuda()) {
+      Tensor nt_sizes_ = target_size.sizes().to(torch::kInt32);
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor must be of nested_dim 2.")
+      Tensor nt_sizes = at::native::narrow(nt_sizes_, 1, 0, 1);
+      int max_size_1 = nt_sizes.max().item<int>();
+      nt_sizes =
+          at::native::cumsum(nt_sizes, 0).to(torch::kInt32).reshape({-1});
+      nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes});
+      Tensor output = torch::empty({target_size.numel()}, padded.options());
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      nested_tensor::cuda::remove_padding_kernelLauncher(
+          padded.data_ptr<float>(),
+          output.data_ptr<float>(),
+          nt_sizes.data_ptr<int>(),
+          *nt_opt_size[0],
+          padded.stride(0),
+          *nt_opt_size[2],
+          defaultStream);
+      return wrap_buffer(std::move(output), target_size, target_stride);
+    }
+  }
+#endif
+  TORCH_CHECK(false, "from_padded_tensor not implemented for this case.");
+}
+
 Tensor to_padded_tensor(Tensor nt, double padding) {
 #ifdef WITH_CUDA
 if (get_dim(nt) == 3 && get_is_contiguous(nt)) {
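
Note: from_padded_tensor above turns the per-sequence lengths stored in the nested size metadata into exclusive offsets for remove_padding_kernelLauncher: narrow out the length column, cumulative-sum it, and prepend a zero so that offsets[i] and offsets[i + 1] bracket sequence i in the packed buffer. A small sketch of just that bookkeeping (assuming libtorch; the literal sizes are made up for illustration):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical nested sizes for 3 sequences of a 3-dim NestedTensor:
  // lengths 2, 4, 3 along dim 1, all with inner size 8.
  at::Tensor nt_sizes_ = torch::tensor({{2, 8}, {4, 8}, {3, 8}}, torch::kInt32);
  at::Tensor lengths = nt_sizes_.narrow(1, 0, 1);                   // [[2],[4],[3]]
  at::Tensor offsets = at::cumsum(lengths, 0).to(torch::kInt32).reshape({-1});
  offsets = at::cat({torch::tensor({0}, torch::kInt32), offsets});  // [0, 2, 6, 9]
  std::cout << offsets << std::endl;
  return 0;
}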

nestedtensor/csrc/masking.h

+10

@@ -4,6 +4,7 @@
 #include <nestedtensor/csrc/python_functions.h>
 #include <nestedtensor/csrc/utils/nested_node_functions.h>
 #include <nestedtensor/csrc/utils/python_nested_node.h>
+#include <nestedtensor/csrc/storage/EfficientSizeNode.h>
 #include <torch/csrc/Size.h>
 #include <torch/csrc/autograd/python_variable_indexing.h>
 #include <torch/extension.h>
@@ -16,6 +17,15 @@ at::Tensor to_mask(
     at::Tensor nt,
     c10::optional<int64_t> mask_dim);

+at::Tensor to_padded_tensor(
+    at::Tensor nt,
+    double padding);
+
+at::Tensor from_padded_tensor(
+    at::Tensor nt,
+    torch::nested_tensor::EfficientSizeNode target_size,
+    torch::nested_tensor::EfficientSizeNode target_stride);
+
 c10::optional<at::Tensor> nt_from_tensor_mask(
     at::Tensor tensor,
     at::Tensor mask,

nestedtensor/csrc/storage/EfficientSizeNode.h

+8 -3

@@ -133,6 +133,9 @@ struct EfficientSizeNode {
   const std::vector<c10::optional<int64_t>>& opt_sizes() const {
     return _opt_sizes;
   }
+  void refresh_opt_sizes() {
+    _opt_sizes = impl::construct_efficient_size(_structure, _height, _sizes);
+  }
   const at::Tensor& sizes() const {
     return _sizes;
   }
@@ -167,7 +170,7 @@ struct EfficientSizeNode {
   std::vector<int64_t> _structure;
   const at::Tensor _sizes;
   bool _opt_sizes_set = false;
-  const std::vector<c10::optional<int64_t>> _opt_sizes;
+  std::vector<c10::optional<int64_t>> _opt_sizes;
 };

 inline bool efficient_size_structure_matches(
@@ -230,10 +233,12 @@ inline void apply_efficient_size(
   }
   for (int64_t i = 0; i < sizes0.size(0); i++) {
     fn(sizes0_ptr + i * sizes0.size(1),
-       sizes0.size(0),
+       sizes0.size(1),
        sizes1_ptr + i * sizes1.size(1),
-       sizes1.size(0));
+       sizes1.size(1));
   }
+  size_node0.refresh_opt_sizes();
+  size_node1.refresh_opt_sizes();
 }

 } // namespace nested_tensor
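
Note: two separate fixes land in apply_efficient_size above: the callback now receives sizes.size(1) (the width of the row it is handed) instead of sizes.size(0) (the number of rows), and the cached _opt_sizes are recomputed after the in-place mutation via the new refresh_opt_sizes(). The snippet below (assuming libtorch; the tensor contents and the lambda are illustrative) shows the corrected row-pointer/row-length pairing in isolation.

#include <torch/torch.h>
#include <cstdint>
#include <iostream>

int main() {
  // A 2D int64 sizes tensor: 3 rows, 2 columns.
  at::Tensor sizes = torch::tensor({{2, 8}, {4, 8}, {3, 8}}, torch::kInt64);
  int64_t* ptr = sizes.data_ptr<int64_t>();
  auto fn = [](int64_t* row, int64_t len) {
    for (int64_t j = 0; j < len; j++) std::cout << row[j] << " ";
    std::cout << std::endl;
  };
  for (int64_t i = 0; i < sizes.size(0); i++) {
    // Pass the row width (size(1)) with each row pointer, as in the fix.
    fn(ptr + i * sizes.size(1), sizes.size(1));
  }
  return 0;
}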

nestedtensor/version.py

+2 -2

@@ -1,5 +1,5 @@
-__version__ = '0.1.4+e1d384f'
-git_version = 'e1d384fea9d70a664b38a53768f82c81057a7d13'
+__version__ = '0.1.4+3a8fd81'
+git_version = '3a8fd81e999271b1ecdbf6cad8d1b6e1718d00c7'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION
