# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | +""" |
| 15 | +Copied from Dino repo. https://github.com/facebookresearch/dino |
| 16 | +Mostly copy-paste from timm library. |
| 17 | +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
| 18 | +""" |
import math
import warnings
from functools import partial

import torch
import torch.nn as nn


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
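
# Note: recent PyTorch releases expose an equivalent initializer as
# torch.nn.init.trunc_normal_; this local copy is kept for compatibility
# with the original DINO code.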


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor  # rescale by 1/keep_prob so the expected value is unchanged
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention map: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, heads merged back to (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
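
    # Worked example (sketch): with patch_size=16 and a model pretrained at
    # 224x224 (a 14x14 grid, N=196 patch tokens), a 480x480 input gives
    # w0 = h0 = 30, so the 14x14 positional grid is bicubic-resized to 30x30
    # and flattened back into 900 patch-position embeddings.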

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

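
# Example (sketch): building a DINO-pretrained ViT-S/16 backbone and extracting
# the [CLS] embedding for a batch. The checkpoint URL below is assumed to be the
# publicly released DINO ViT-S/16 weights; adjust it if your copy lives elsewhere.
#
#   model = vit_small(patch_size=16, num_classes=0)
#   state_dict = torch.hub.load_state_dict_from_url(
#       "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth")
#   model.load_state_dict(state_dict, strict=True)
#   cls_feat = model(torch.randn(2, 3, 224, 224))  # -> (2, 384) [CLS] embeddings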


class ViTFeat(nn.Module):
    """DINO ViT feature extractor returning q/k/v patch features from the last block."""
    def __init__(self, pretrained_pth, feat_dim, vit_arch='base', vit_feat='k', patch_size=16):
        super().__init__()
        if vit_arch == 'base':
            self.model = vit_base(patch_size=patch_size, num_classes=0)
        else:
            self.model = vit_small(patch_size=patch_size, num_classes=0)

        self.feat_dim = feat_dim
        self.vit_feat = vit_feat
        self.patch_size = patch_size

        # state_dict = torch.load(pretrained_pth, map_location="cpu")
        # `pretrained_pth` is a checkpoint path relative to https://dl.fbaipublicfiles.com
        state_dict = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com" + pretrained_pth)
        self.model.load_state_dict(state_dict, strict=True)
        print('Loading weights from {}'.format(pretrained_pth))

    def forward(self, img):
        feat_out = {}

        def hook_fn_forward_qkv(module, input, output):
            # Capture the raw qkv projection of the last block's attention layer.
            feat_out["qkv"] = output

        # Register the hook for this forward pass only; keep the handle so the hook
        # can be removed afterwards and does not accumulate across calls.
        handle = self.model.blocks[-1].attn.qkv.register_forward_hook(hook_fn_forward_qkv)

        # Forward pass through the model
        with torch.no_grad():
            h, w = img.shape[2], img.shape[3]
            feat_h, feat_w = h // self.patch_size, w // self.patch_size
            attentions = self.model.get_last_selfattention(img)
            bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2]
            qkv = (
                feat_out["qkv"]
                .reshape(bs, nb_token, 3, nb_head, -1)
                .permute(2, 0, 3, 1, 4)
            )  # (3, bs, nb_head, nb_token, head_dim)
            q, k, v = qkv[0], qkv[1], qkv[2]

            # Merge heads back: (bs, nb_token, embed_dim)
            k = k.transpose(1, 2).reshape(bs, nb_token, -1)
            q = q.transpose(1, 2).reshape(bs, nb_token, -1)
            v = v.transpose(1, 2).reshape(bs, nb_token, -1)

            # Modality selection: drop the [CLS] token and reshape the patch tokens
            # to (bs, feat_dim, feat_h * feat_w)
            if self.vit_feat == "k":
                feats = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
            elif self.vit_feat == "q":
                feats = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
            elif self.vit_feat == "v":
                feats = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
            elif self.vit_feat == "kqv":
                k = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
                q = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
                v = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
                feats = torch.cat([k, q, v], dim=1)

        handle.remove()
        return feats
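
    # The returned tensor has shape (B, feat_dim, feat_h * feat_w) -- or
    # (B, 3 * feat_dim, feat_h * feat_w) when vit_feat == "kqv" -- where
    # feat_h = H // patch_size and feat_w = W // patch_size.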


if __name__ == "__main__":
    vit_arch = 'base'
    vit_feat = 'k'

    # NOTE: illustrative smoke test; the checkpoint path below is assumed to be the
    # official DINO ViT-B/16 weights hosted under https://dl.fbaipublicfiles.com,
    # and feat_dim=768 matches the ViT-B embedding dimension.
    pretrained_pth = "/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    model = ViTFeat(pretrained_pth, feat_dim=768, vit_arch=vit_arch, vit_feat=vit_feat, patch_size=16)
    model.cuda()
    img = torch.randn(4, 3, 224, 224, device="cuda")
    # Forward pass through the model
    feat = model(img)
    print(feat.shape)