Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit e2aaac9

Browse files
authored
More efficient MHA - recycle input_mask for sequence mask (#406)
1 parent e1d384f commit e2aaac9

File tree

2 files changed

+4
-17
lines changed

2 files changed

+4
-17
lines changed

nestedtensor/csrc/cuda/mha.cpp

+2-15
Original file line number | Diff line number | Diff line change
@@ -19,18 +19,6 @@ using namespace at;
1919
namespace torch {
2020
namespace nested_tensor {
2121

22-
at::Tensor _sequence_mask(at::Tensor lengths) {
23-
int64_t batch_size = lengths.numel();
24-
int64_t max_len = lengths.max().item<int64_t>();
25-
at::Tensor mask = torch::arange(0, max_len, torch::kFloat);
26-
mask = mask.repeat({batch_size, 1});
27-
mask = mask.lt(lengths.unsqueeze(1));
28-
mask = mask.to(torch::kCUDA);
29-
mask = mask.view({-1, 1, 1, max_len});
30-
at::Tensor m2 = mask.transpose(2, 3);
31-
return mask * m2;
32-
}
33-
3422
at::Tensor bt_min_mha(
3523
int64_t num_heads,
3624
int64_t head_dim,
@@ -86,9 +74,8 @@ at::Tensor bt_min_mha(
8674
TORCH_CHECK(query_esize.height() == 1, "Query nested dim isn't 1.");
8775
auto query_esize_sizes = query_esize.sizes();
8876

89-
at::Tensor attr_mask = _sequence_mask(
90-
at::native::select(query_esize_sizes, 1, 0).contiguous());
91-
attr_mask = attr_mask.to(float_options);
77+
at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
78+
attr_mask = attr_mask * attr_mask.transpose(2, 3);
9279

9380
nteffectivetransformer::exclusiveScan_kernelLauncher(
9481
prefix_sum_ptr,

nestedtensor/version.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
__version__ = '0.1.4+f20ca2f'
2-
git_version = 'f20ca2f38aaf234c1c5b85fc3b07fbe2e291cea5'
1+
__version__ = '0.1.4+e1d384f'
2+
git_version = 'e1d384fea9d70a664b38a53768f82c81057a7d13'
33
from nestedtensor import _C
44
if hasattr(_C, 'CUDA_VERSION'):
55
cuda = _C.CUDA_VERSION

0 commit comments

Comments (0)