Wrong attention mask implementation in BiMultiHeadAttention #12351

Open
xiuqhou opened this issue Apr 29, 2025 · 0 comments
xiuqhou commented Apr 29, 2025

Describe the bug
In the implementation of BiMultiHeadAttention, a boolean attention_mask is filled with a large negative value so that unused positions are ignored:

https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/utils/vlfuse_helper.py#L200-L201

            attention_mask = attention_mask.masked_fill(
                attention_mask == 0, -9e15)

            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError('Attention mask should be of '
                                 f'size {(bsz, 1, tgt_len, src_len)}')
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len,
                                             src_len)

However, there are two mistakes:

  1. masked_fill does not promote a boolean tensor to torch.float when given a non-zero float fill value; it still returns a boolean tensor, with the filled positions set to True. When this attention_mask is later added to attn_weights, the positions that should be ignored only receive +1 (True is cast to 1) instead of -9e15, so they are never actually masked out (see the snippet after this list).
  2. In attention_mask, a True value indicates that the corresponding position should not be attended. Therefore, the positions to fill should be selected with attention_mask == True (or simply attention_mask), not attention_mask == 0.
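
As a quick illustration, here is a standalone PyTorch snippet (not mmdetection code; the toy mask assumes the True-means-masked convention described in point 2):

    import torch

    # Toy boolean mask: True marks positions that should NOT be attended to.
    attention_mask = torch.tensor([[False, False, True, True]])
    attn_weights = torch.zeros(1, 4)

    # What the current code does: masked_fill on a bool tensor keeps dtype torch.bool,
    # so the -9e15 fill value is silently converted to True.
    filled = attention_mask.masked_fill(attention_mask == 0, -9e15)
    print(filled.dtype)           # torch.bool
    print(attn_weights + filled)  # tensor([[1., 1., 1., 1.]]) -- every position only gets +1

    # What is presumably intended: build a float mask and fill the True positions.
    float_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype)
    float_mask = float_mask.masked_fill(attention_mask, -9e15)
    print(attn_weights + float_mask)  # masked positions now get -9e15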

Bug fix
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
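
For reference, a minimal sketch of one possible fix, assuming the True-means-masked convention above (variable names follow the quoted snippet; the exact mask convention should be confirmed against the callers of BiMultiHeadAttention):

    # Sketch only: turn the boolean mask into an additive float mask before adding it
    # to attn_weights, so that masked (True) positions receive a large negative bias.
    if attention_mask.dtype == torch.bool:
        additive_mask = torch.zeros_like(attention_mask,
                                         dtype=attn_weights.dtype)
        additive_mask = additive_mask.masked_fill(attention_mask, -9e15)
    else:
        additive_mask = attention_mask

    if additive_mask.size() != (bsz, 1, tgt_len, src_len):
        raise ValueError('Attention mask should be of '
                         f'size {(bsz, 1, tgt_len, src_len)}')
    attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                     src_len) + additive_mask
    attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)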
