@@ -36,6 +36,9 @@ at::Tensor bt_min_mha(
   TORCH_CHECK(get_dim(query) == 3, "query needs to be 3 dim.");
   TORCH_CHECK(get_dim(key) == 3, "key needs to be 3 dim.");
   TORCH_CHECK(get_dim(value) == 3, "value needs to be 3 dim.");
+  TORCH_CHECK(get_nested_dim(query) == 1, "Query nested dim isn't 1.");
+  TORCH_CHECK(get_nested_dim(key) == 1, "Key nested dim isn't 1.");
+  TORCH_CHECK(get_nested_dim(value) == 1, "Value nested dim isn't 1.");
   // TORCH_CHECK(in_proj_bias, "Input projection bias needs to be defined.");
   // auto opt_sizes = get_opt_sizes(query);
   // if (!opt_sizes[2]) {
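For orientation, the three new checks pin down the input layout: nested dim 1 over tensor dim 3 means each of query/key/value is a flat list of batch_size matrices of shape [seq_len_i, embedding_dim]. A minimal sketch of that invariant, with plain ATen tensors standing in for the nested components (the names below are illustrative, not the repo's API):

#include <torch/torch.h>
#include <vector>

int main() {
  // A "3-dim, nested-dim-1" input: one list level over 2-dim tensors,
  // i.e. per-sequence [seq_len_i, embedding_dim] matrices.
  int64_t embedding_dim = 16;
  std::vector<at::Tensor> components = {
      torch::randn({5, embedding_dim}),  // sequence of length 5
      torch::randn({3, embedding_dim}),  // sequence of length 3
  };
  int64_t nested_dim = 1;                          // one list level
  int64_t dim = nested_dim + components[0].dim();  // 1 + 2 == 3
  TORCH_CHECK(dim == 3, "query needs to be 3 dim.");
  TORCH_CHECK(nested_dim == 1, "Query nested dim isn't 1.");
}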
@@ -57,88 +60,31 @@ at::Tensor bt_min_mha(
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
   at::cuda::setCurrentCUDAStream(defaultStream);
 
-  int64_t input_tensor_size = batch_size * head_num * seq_len * size_per_head;
-  int64_t attn_tensor_size = batch_size * head_num * seq_len * seq_len;
-  int word_num = batch_size * seq_len;
-  Tensor prefix_sum = torch::zeros({word_num}, options);
-  Tensor batch_idx = torch::zeros({word_num}, options);
-  Tensor word_idx = torch::zeros({word_num}, options);
+  at::Tensor packed = at::matmul(query, attr_kernel.t()) + attr_bias;
 
-  int* prefix_sum_ptr = prefix_sum.data_ptr<int>();
-  int* batch_idx_ptr = batch_idx.data_ptr<int>();
-  int* word_idx_ptr = word_idx.data_ptr<int>();
-
-  at::Tensor tmp = get_buffer(query);
-
-  auto query_esize = get_efficient_nested_size(query);
-  TORCH_CHECK(query_esize.height() == 1, "Query nested dim isn't 1.");
-  auto query_esize_sizes = query_esize.sizes();
-
-  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
-  attr_mask = attr_mask * attr_mask.transpose(2, 3);
-
-  nteffectivetransformer::exclusiveScan_kernelLauncher(
-      prefix_sum_ptr,
-      input_mask.data_ptr<int>(),
-      input_mask.size(0) * input_mask.size(1),
-      defaultStream);
-
-  nteffectivetransformer::compressBertInput_kernelLauncher(
-      input_mask.data_ptr<int>(),
-      prefix_sum_ptr,
-      batch_idx_ptr,
-      word_idx_ptr,
-      (int32_t)(batch_size),
-      (int32_t)(seq_len),
-      (int32_t)(embedding_dim),
-      defaultStream);
-
-  at::Tensor packed = at::matmul(query, attr_kernel.t());
+  // TODO: Move into implementation of chunk for NestedTensor
   at::Tensor packed_buf = get_buffer(packed).contiguous().reshape({-1, 3 * embedding_dim});
   std::vector<at::Tensor> packed_chunks = packed_buf.chunk(3, -1);
-  at::Tensor q_buf = packed_chunks[0].contiguous().reshape({-1});
-  at::Tensor k_buf = packed_chunks[1].contiguous().reshape({-1});
-  at::Tensor v_buf = packed_chunks[2].contiguous().reshape({-1});
-
-  int valid_word_num = get_numel(query) / embedding_dim;
-
-  at::Tensor query_buf = torch::zeros(
-      {batch_size, head_num, seq_len, size_per_head}, float_options);
-  at::Tensor key_buf = torch::zeros(
-      {batch_size, head_num, seq_len, size_per_head}, float_options);
-  at::Tensor val_buf = torch::zeros(
-      {batch_size, head_num, seq_len, size_per_head}, float_options);
-  at::Tensor attr_out =
-      torch::zeros({valid_word_num, embedding_dim}, float_options);
-
-  std::vector<at::Tensor> bias_chunks = attr_bias.chunk(3);
-  at::Tensor attr_bias_Q = bias_chunks[0];
-  at::Tensor attr_bias_K = bias_chunks[1];
-  at::Tensor attr_bias_V = bias_chunks[2];
-
-  nteffectivetransformer::cuda::add_QKV_bias_padding_kernelLauncher<float>(
-      q_buf.data_ptr<float>(),
-      attr_bias_Q.data_ptr<float>(),
-      k_buf.data_ptr<float>(),
-      attr_bias_K.data_ptr<float>(),
-      v_buf.data_ptr<float>(),
-      attr_bias_V.data_ptr<float>(),
-      query_buf.data_ptr<float>(),
-      key_buf.data_ptr<float>(),
-      val_buf.data_ptr<float>(),
-      valid_word_num,
-      batch_size,
-      seq_len,
-      head_num,
-      size_per_head,
-      batch_idx_ptr,
-      word_idx_ptr,
-      defaultStream);
+  at::Tensor q_buf_ = packed_chunks[0].contiguous().reshape({-1});
+  at::Tensor k_buf_ = packed_chunks[1].contiguous().reshape({-1});
+  at::Tensor v_buf_ = packed_chunks[2].contiguous().reshape({-1});
+  at::Tensor q = wrap_buffer(std::move(q_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
+  at::Tensor k = wrap_buffer(std::move(k_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
+  at::Tensor v = wrap_buffer(std::move(v_buf_), get_efficient_nested_size(query), get_efficient_nested_stride(query));
+
+  at::Tensor query_buf = to_padded_tensor(q, 0).contiguous();
+  at::Tensor key_buf = to_padded_tensor(k, 0).contiguous();
+  at::Tensor val_buf = to_padded_tensor(v, 0).contiguous();
+  query_buf = query_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
+  key_buf = key_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
+  val_buf = val_buf.reshape({batch_size, seq_len, head_num, size_per_head}).transpose(1, 2);
 
   key_buf = key_buf.transpose(2, 3);
   at::Tensor attn_output_weights = at::matmul(query_buf, key_buf).contiguous();
 
+  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
+  attr_mask = attr_mask * attr_mask.transpose(2, 3);
+
   nteffectivetransformer::cuda::softmax_kernel_kernelLauncher<float>(
       attn_output_weights.data_ptr<float>(),
       attr_mask.data_ptr<float>(),
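One detail worth a closer look: the mask construction moved after the QK^T matmul, but the math is unchanged. `input_mask` is a [batch_size, seq_len] 0/1 token mask, and the broadcasted self-product turns it into a per-batch [seq_len, seq_len] attention mask that keeps a position pair only if both tokens are real. A standalone sketch with hypothetical values, using plain ATen:

#include <torch/torch.h>
#include <iostream>

int main() {
  // [batch_size, seq_len] token mask: 1 = real token, 0 = padding.
  auto input_mask = torch::tensor({{1, 1, 1, 0},
                                   {1, 1, 0, 0}});
  int64_t seq_len = input_mask.size(1);
  // [B, 1, 1, S] * [B, 1, S, 1] broadcasts to [B, 1, S, S]:
  // entry (i, j) is 1 iff tokens i and j are both real.
  auto attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(torch::kFloat);
  attr_mask = attr_mask * attr_mask.transpose(2, 3);
  std::cout << attr_mask.sizes() << "\n";  // [2, 1, 4, 4]
}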
@@ -148,27 +94,10 @@ at::Tensor bt_min_mha(
       (float)(scaling),
       defaultStream);
 
-  auto attn_output = at::matmul(attn_output_weights, val_buf);
-
-  nteffectivetransformer::cuda::transpose_rm_padding_kernelLauncher<float>(
-      attn_output.data_ptr<float>(),
-      attr_out.data_ptr<float>(),
-      valid_word_num,
-      batch_size,
-      seq_len,
-      head_num,
-      size_per_head,
-      batch_idx_ptr,
-      word_idx_ptr,
-      defaultStream);
-
-  // TODO: Bias is variably sized, need to add support for that.
-  at::Tensor result = at::matmul(attr_out, out_proj_weight.t());
-  result = result.reshape({-1});
-  return wrap_buffer(
-      std::move(result),
-      get_efficient_nested_size(query),
-      get_efficient_nested_stride(query));
+  auto attn_output = at::matmul(attn_output_weights, val_buf).contiguous();
+  attn_output = attn_output.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim}).contiguous();
+  at::Tensor attr_out = from_padded_tensor(attn_output, get_efficient_nested_size(query), get_efficient_nested_stride(query));
+  return at::matmul(attr_out, out_proj_weight.t());
 }
 
 TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
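To see the padded path end to end, here is a dense-tensor analogue of the rewritten flow: packed QKV projection, head split, scaled-dot-product attention, head merge, output projection. It is a sketch only, assuming every sequence already has length seq_len (so padding is the identity) and using at::softmax in place of the effective-transformer softmax kernel; all names are local stand-ins, not the patch's NestedTensor helpers.

#include <torch/torch.h>
#include <cmath>

int main() {
  int64_t batch_size = 2, seq_len = 4, head_num = 2, size_per_head = 8;
  int64_t embedding_dim = head_num * size_per_head;
  double scaling = 1.0 / std::sqrt((double)size_per_head);

  auto query = torch::randn({batch_size, seq_len, embedding_dim});
  auto attr_kernel = torch::randn({3 * embedding_dim, embedding_dim});
  auto attr_bias = torch::randn({3 * embedding_dim});
  auto out_proj_weight = torch::randn({embedding_dim, embedding_dim});

  // Packed QKV projection: one matmul plus bias, then one chunk,
  // instead of three separate projections.
  auto packed = at::matmul(query, attr_kernel.t()) + attr_bias;
  auto chunks = packed.chunk(3, -1);

  // [B, S, E] -> [B, H, S, D], as in query_buf/key_buf/val_buf above.
  auto split_heads = [&](const at::Tensor& t) {
    return t.reshape({batch_size, seq_len, head_num, size_per_head})
        .transpose(1, 2)
        .contiguous();
  };
  auto q = split_heads(chunks[0]);
  auto k = split_heads(chunks[1]);
  auto v = split_heads(chunks[2]);

  // Scaled dot-product attention over the padded layout.
  auto weights = at::softmax(at::matmul(q, k.transpose(2, 3)) * scaling, -1);
  auto attn_output = at::matmul(weights, v);

  // [B, H, S, D] -> [B, S, E], then the output projection.
  attn_output =
      attn_output.transpose(1, 2).reshape({batch_size, seq_len, embedding_dim});
  auto result = at::matmul(attn_output, out_proj_weight.t());
  TORCH_CHECK(result.sizes() == query.sizes());
}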