Describe the bug

In the implementation of BiMultiHeadAttention, a boolean attention_mask is filled with a small value to ignore unused elements: https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/utils/vlfuse_helper.py#L200-#L201

```python
attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
    raise ValueError('Attention mask should be of '
                     f'size {(bsz, 1, tgt_len, src_len)}')
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                 src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
```
However, there are two mistakes:
1. In masked_fill, filling a non-zero float into a boolean tensor does not convert its dtype to torch.float automatically. The call still returns a boolean tensor in which the filled positions are simply True. When this attention_mask is then added to attn_weights, every position that should be ignored is increased by 1 (the True values are promoted to 1 during the addition) instead of by -9e15.
2. In attention_mask, a True value indicates that the corresponding position is not allowed to attend. Therefore, the positions to fill should be attention_mask == True, not attention_mask == 0.
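Both mistakes can be reproduced in a few lines, independent of mmdetection (a standalone sketch; the tensor values here are made up for illustration):

```python
import torch

# True marks positions that are NOT allowed to attend.
attention_mask = torch.tensor([[True, False, False]])

# Buggy pattern: masked_fill keeps the tensor's bool dtype, so -9e15 is
# cast to True at the filled positions instead of producing a float tensor.
filled = attention_mask.masked_fill(attention_mask == 0, -9e15)
print(filled.dtype)  # torch.bool, not torch.float

# Adding the bool mask to float attention weights promotes True to 1.0,
# so positions get +1 rather than -9e15.
attn_weights = torch.zeros(1, 3)
print(attn_weights + filled)
```

Note that because of the inverted condition (== 0), the positions that *should* attend also end up filled with True, so every weight is shifted by 1.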
Bug fix
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
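One possible correction (a minimal sketch with made-up shapes and names, not a tested patch against mmdetection): convert the boolean mask into an additive float mask first, filling the True (disallowed) positions.

```python
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3
attn_weights = torch.zeros(bsz * num_heads, tgt_len, src_len)
# True marks positions that are NOT allowed to attend; here the last key
# position is masked out, purely for illustration.
attention_mask = torch.zeros(bsz, 1, tgt_len, src_len, dtype=torch.bool)
attention_mask[..., -1] = True

# Build a float additive mask: -9e15 where attending is disallowed
# (attention_mask == True), 0 elsewhere.
additive_mask = torch.zeros_like(
    attention_mask, dtype=attn_weights.dtype).masked_fill(
        attention_mask, -9e15)

attn_weights = attn_weights.view(bsz, num_heads, tgt_len,
                                 src_len) + additive_mask
attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)
print(attn_weights[0])  # last column -9e15, the rest 0
```

After a subsequent softmax, the -9e15 entries contribute (practically) zero probability, which is the intended masking behavior.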