Commit 17e7b94

authored Jul 27, 2023
first release
1 parent 8674278 commit 17e7b94

File tree

16 files changed: +2755 -0 lines changed
 

‎common/evaluation.py

+39
@@ -0,0 +1,39 @@
r""" Evaluate mask prediction """
import torch


class Evaluator:
    r""" Computes intersection and union between prediction and ground-truth """
    @classmethod
    def initialize(cls):
        cls.ignore_index = 255

    @classmethod
    def classify_prediction(cls, pred_mask, batch):
        gt_mask = batch.get('query_mask')

        # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
        query_ignore_idx = batch.get('query_ignore_idx')
        if query_ignore_idx is not None:
            assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
            query_ignore_idx *= cls.ignore_index
            gt_mask = gt_mask + query_ignore_idx
            pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index

        # compute intersection and union of each episode in a batch
        area_inter, area_pred, area_gt = [], [], []
        for _pred_mask, _gt_mask in zip(pred_mask, gt_mask):
            _inter = _pred_mask[_pred_mask == _gt_mask]
            if _inter.size(0) == 0:  # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1)
                _area_inter = torch.tensor([0, 0], device=_pred_mask.device)
            else:
                _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
            area_inter.append(_area_inter)
            area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
            area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
        area_inter = torch.stack(area_inter).t()
        area_pred = torch.stack(area_pred).t()
        area_gt = torch.stack(area_gt).t()
        area_union = area_pred + area_gt - area_inter

        return area_inter, area_union
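
A minimal usage sketch for the evaluator (hypothetical tensors; the training/testing driver that actually calls this is not part of this commit, and the repository root is assumed to be on PYTHONPATH):

import torch
from common.evaluation import Evaluator

Evaluator.initialize()
pred_mask = torch.randint(0, 2, (4, 400, 400)).float()   # pretend batch of binary predictions
batch = {'query_mask': torch.randint(0, 2, (4, 400, 400)).float()}
area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
print(area_inter.shape, area_union.shape)                 # both [2, 4]: background/foreground rows per episode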

‎common/logger.py

+134
@@ -0,0 +1,134 @@
r""" Logging during training/testing """
import datetime
import logging
import os

from tensorboardX import SummaryWriter
import torch


class AverageMeter:
    r""" Stores loss, evaluation results """
    def __init__(self, dataset):
        self.benchmark = dataset.benchmark
        self.class_ids_interest = dataset.class_ids
        self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()

        if self.benchmark == 'pascal':
            self.nclass = 20
        elif self.benchmark == 'coco':
            self.nclass = 80
        elif self.benchmark == 'fss':
            self.nclass = 1000
        elif self.benchmark == 'ph2':
            self.nclass = 1
        self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
        self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        self.intersection_buf.index_add_(1, class_id, inter_b.float())
        self.union_buf.index_add_(1, class_id, union_b.float())
        if loss is None:
            loss = torch.tensor(0.0)
        self.loss_buf.append(loss)

    def compute_iou(self):
        iou = self.intersection_buf.float() / \
              torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
        iou = iou.index_select(1, self.class_ids_interest)
        miou = iou[1].mean() * 100

        fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
                  self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100

        return miou, fb_iou

    def write_result(self, split, epoch):
        iou, fb_iou = self.compute_iou()

        loss_buf = torch.stack(self.loss_buf)
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f ' % loss_buf.mean()
        msg += 'mIoU: %5.2f ' % iou
        msg += 'FB-IoU: %5.2f ' % fb_iou

        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        if batch_idx % write_batch_idx == 0:
            msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
            msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
            iou, fb_iou = self.compute_iou()
            if epoch != -1:
                loss_buf = torch.stack(self.loss_buf)
                msg += 'L: %6.5f ' % loss_buf[-1]
                msg += 'Avg L: %6.5f ' % loss_buf.mean()
            msg += 'mIoU: %5.2f | ' % iou
            msg += 'FB-IoU: %5.2f' % fb_iou
            Logger.info(msg)


class Logger:
    r""" Writes evaluation results of training/testing """
    @classmethod
    def initialize(cls, args, training):
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-2].split('.')[0] + logtime
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:=========== Few-shot Seg. with HSNet ===========')
        for arg_key in args.__dict__:
            logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
        logging.info(':================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes log message to log.txt """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'best_model.pt'))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))

    @classmethod
    def log_params(cls, model):
        backbone_param = 0
        learner_param = 0
        for k in model.state_dict().keys():
            n_param = model.state_dict()[k].view(-1).size(0)
            if k.split('.')[0] in 'backbone':
                if k.split('.')[1] in ['classifier', 'fc']:  # as fc layers are not used in HSNet
                    continue
                backbone_param += n_param
            else:
                learner_param += n_param
        Logger.info('Backbone # param.: %d' % backbone_param)
        Logger.info('Learnable # param.: %d' % learner_param)
        Logger.info('Total # param.: %d' % (backbone_param + learner_param))
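
A sketch of how AverageMeter is typically driven (assumptions: the real train/test loop lives in scripts outside this commit, and a CUDA device is required because the buffers are allocated with .cuda()):

import torch
from common.logger import AverageMeter

class _DummyDataset:                         # stand-in exposing the two attributes AverageMeter reads
    benchmark = 'fss'
    class_ids = list(range(760, 1000))       # test-split class ids, as defined in data/fss.py

meter = AverageMeter(_DummyDataset())
inter = torch.ones(2, 4).cuda()              # pretend intersection counts for 4 episodes
union = torch.full((2, 4), 2.0).cuda()       # pretend union counts
class_id = torch.tensor([760, 761, 762, 763]).cuda()
meter.update(inter, union, class_id, loss=None)
miou, fb_iou = meter.compute_iou()           # percentages; mIoU averages over all 240 test classes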

‎common/utils.py

+32
@@ -0,0 +1,32 @@
r""" Helper functions """
import random

import torch
import numpy as np


def fix_randseed(seed):
    r""" Set random seeds for reproducibility """
    if seed is None:
        seed = int(random.random() * 1e5)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def mean(x):
    return sum(x) / len(x) if len(x) > 0 else 0.0


def to_cuda(batch):
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.cuda()
    return batch


def to_cpu(tensor):
    return tensor.detach().clone().cpu()
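
A short sketch of the intended call pattern for these helpers (the batch dict mirrors what data/fss.py returns; the tensors here are placeholders):

import torch
from common import utils

utils.fix_randseed(0)                                            # seeded RNGs + deterministic cuDNN
batch = {'query_img': torch.randn(1, 3, 400, 400), 'query_name': 'example.jpg'}
batch = utils.to_cuda(batch)                                     # only tensor values are moved to the GPU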

‎common/vis.py

+106
@@ -0,0 +1,106 @@
r""" Visualize model predictions """
import os

from PIL import Image
import numpy as np
import torchvision.transforms as transforms

from . import utils


class Visualizer:

    @classmethod
    def initialize(cls, visualize):
        cls.visualize = visualize
        if not visualize:
            return

        cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
        for key, value in cls.colors.items():
            cls.colors[key] = tuple([c / 255 for c in cls.colors[key]])

        cls.mean_img = [0.485, 0.456, 0.406]
        cls.std_img = [0.229, 0.224, 0.225]
        cls.to_pil = transforms.ToPILImage()
        cls.vis_path = './vis/'
        if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path)

    @classmethod
    def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
        spt_img_b = utils.to_cpu(spt_img_b)
        spt_mask_b = utils.to_cpu(spt_mask_b)
        qry_img_b = utils.to_cpu(qry_img_b)
        qry_mask_b = utils.to_cpu(qry_mask_b)
        pred_mask_b = utils.to_cpu(pred_mask_b)
        cls_id_b = utils.to_cpu(cls_id_b)

        for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \
                enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)):
            iou = iou_b[sample_idx] if iou_b is not None else None
            cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou)

    @classmethod
    def to_numpy(cls, tensor, type):
        if type == 'img':
            return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8)
        elif type == 'mask':
            return np.array(tensor).astype(np.uint8)
        else:
            raise Exception('Undefined tensor type: %s' % type)

    @classmethod
    def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):

        spt_color = cls.colors['blue']
        qry_color = cls.colors['red']
        pred_color = cls.colors['red']

        spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs]
        spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs]
        spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks]
        spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)]

        qry_img = cls.to_numpy(qry_img, 'img')
        qry_pil = cls.to_pil(qry_img)
        qry_mask = cls.to_numpy(qry_mask, 'mask')
        pred_mask = cls.to_numpy(pred_mask, 'mask')
        pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color))
        qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color))

        merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])

        iou = iou.item() if iou else 0.0
        merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg')

    @classmethod
    def merge_image_pair(cls, pil_imgs):
        r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """

        canvas_width = sum([pil.size[0] for pil in pil_imgs])
        canvas_height = max([pil.size[1] for pil in pil_imgs])
        canvas = Image.new('RGB', (canvas_width, canvas_height))

        xpos = 0
        for pil in pil_imgs:
            canvas.paste(pil, (xpos, 0))
            xpos += pil.size[0]

        return canvas

    @classmethod
    def apply_mask(cls, image, mask, color, alpha=0.5):
        r""" Apply mask to the given image. """
        for c in range(3):
            image[:, :, c] = np.where(mask == 1,
                                      image[:, :, c] *
                                      (1 - alpha) + alpha * color[c] * 255,
                                      image[:, :, c])
        return image

    @classmethod
    def unnormalize(cls, img):
        img = img.clone()
        for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
            im_channel.mul_(std).add_(mean)
        return img

‎create_mask/creating_mask.ipynb

+348
@@ -0,0 +1,348 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import sys\n",
10+
"sys.path.append('./model')\n",
11+
"import dino # model\n",
12+
"import argparse\n",
13+
"import utils\n",
14+
"import os\n",
15+
"\n",
16+
"import PIL.Image as Image\n",
17+
"import cv2\n",
18+
"import numpy as np\n",
19+
"from tqdm import tqdm\n",
20+
"\n",
21+
"from torchvision import transforms\n",
22+
"\n",
23+
"import matplotlib.pyplot as plt\n",
24+
"%matplotlib inline\n",
25+
"\n",
26+
"import torch\n",
27+
"import torch.nn.functional as F\n",
28+
"import numpy as np\n",
29+
"from scipy.linalg import eigh\n",
30+
"from scipy import ndimage\n",
31+
"import torch\n",
32+
"import torch.nn.functional as F\n",
33+
"import numpy as np\n",
34+
"import glob\n"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 2,
40+
"metadata": {},
41+
"outputs": [],
42+
"source": [
43+
"def ncut(feats, dims, scales, init_image_size, tau = 0, eps=1e-5, im_name='', no_binary_graph=False):\n",
44+
" \"\"\"\n",
45+
" Implementation of NCut Method.\n",
46+
" Inputs\n",
47+
" feats: the pixel/patche features of an image\n",
48+
" dims: dimension of the map from which the features are used\n",
49+
" scales: from image to map scale\n",
50+
" init_image_size: size of the image\n",
51+
" tau: thresold for graph construction\n",
52+
" eps: graph edge weight\n",
53+
" im_name: image_name\n",
54+
" no_binary_graph: ablation study for using similarity score as graph edge weight\n",
55+
" \"\"\"\n",
56+
" feats = F.normalize(feats, p=2, dim=0)\n",
57+
" A = (feats.transpose(0,1) @ feats)\n",
58+
" A = A.cpu().numpy()\n",
59+
" if no_binary_graph:\n",
60+
" A[A<tau] = eps\n",
61+
" else:\n",
62+
" A = A > tau\n",
63+
" A = np.where(A.astype(float) == 0, eps, A)\n",
64+
" d_i = np.sum(A, axis=1)\n",
65+
" D = np.diag(d_i)\n",
66+
"\n",
67+
" # Print second and third smallest eigenvector\n",
68+
" _, eigenvectors = eigh(D-A, D, subset_by_index=[1,2])\n",
69+
" eigenvec = np.copy(eigenvectors[:, 0])\n",
70+
"\n",
71+
"\n",
72+
" # method1 avg\n",
73+
" second_smallest_vec = eigenvectors[:, 0]\n",
74+
" avg = np.sum(second_smallest_vec) / len(second_smallest_vec)\n",
75+
" bipartition = second_smallest_vec > avg\n",
76+
"\n",
77+
" seed = np.argmax(np.abs(second_smallest_vec))\n",
78+
"\n",
79+
" if bipartition[seed] != 1:\n",
80+
" eigenvec = eigenvec * -1\n",
81+
" bipartition = np.logical_not(bipartition)\n",
82+
" bipartition = bipartition.reshape(dims).astype(float)\n",
83+
"\n",
84+
" # predict BBox\n",
85+
" pred, _, objects,cc = detect_box(bipartition, seed, dims, scales=scales, initial_im_size=init_image_size) ## We only extract the principal object BBox\n",
86+
" mask = np.zeros(dims)\n",
87+
" mask[cc[0],cc[1]] = 1\n",
88+
"\n",
89+
" mask = torch.from_numpy(mask).to('cuda')\n",
90+
"# mask = torch.from_numpy(bipartition).to('cuda')\n",
91+
" bipartition = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=init_image_size, mode='nearest').squeeze()\n",
92+
" \n",
93+
"\n",
94+
" eigvec = second_smallest_vec.reshape(dims) \n",
95+
" eigvec = torch.from_numpy(eigvec).to('cuda')\n",
96+
" eigvec = F.interpolate(eigvec.unsqueeze(0).unsqueeze(0), size=init_image_size, mode='nearest').squeeze()\n",
97+
" return seed, bipartition.cpu().numpy(), eigvec.cpu().numpy(), eigenvectors\n",
98+
"\n",
99+
"def detect_box(bipartition, seed, dims, initial_im_size=None, scales=None, principle_object=True):\n",
100+
" \"\"\"\n",
101+
" Extract a box corresponding to the seed patch. Among connected components extract from the affinity matrix, select the one corresponding to the seed patch.\n",
102+
" \"\"\"\n",
103+
" w_featmap, h_featmap = dims\n",
104+
" objects, num_objects = ndimage.label(bipartition)\n",
105+
" cc = objects[np.unravel_index(seed, dims)]\n",
106+
"\n",
107+
"\n",
108+
" if principle_object:\n",
109+
" mask = np.where(objects == cc)\n",
110+
" # Add +1 because excluded max\n",
111+
" ymin, ymax = min(mask[0]), max(mask[0]) + 1\n",
112+
" xmin, xmax = min(mask[1]), max(mask[1]) + 1\n",
113+
" # Rescale to image size\n",
114+
" r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax\n",
115+
" r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax\n",
116+
" pred = [r_xmin, r_ymin, r_xmax, r_ymax]\n",
117+
"\n",
118+
" # Check not out of image size (used when padding)\n",
119+
" if initial_im_size:\n",
120+
" pred[2] = min(pred[2], initial_im_size[1])\n",
121+
" pred[3] = min(pred[3], initial_im_size[0])\n",
122+
"\n",
123+
" # Coordinate predictions for the feature space\n",
124+
" # Axis different then in image space\n",
125+
" pred_feats = [ymin, xmin, ymax, xmax]\n",
126+
"\n",
127+
" return pred, pred_feats, objects, mask\n",
128+
" else:\n",
129+
" raise NotImplementedError\n"
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": 3,
135+
"metadata": {},
136+
"outputs": [],
137+
"source": [
138+
"# Image transformation applied to all images\n",
139+
"ToTensor = transforms.Compose([\n",
140+
" transforms.ToTensor(),\n",
141+
" transforms.Normalize((0.485, 0.456, 0.406),\n",
142+
" (0.229, 0.224, 0.225)),])\n",
143+
"\n",
144+
"def get_tokencut_binary_map(img_pth, backbone,patch_size, tau) :\n",
145+
" I = Image.open(img_pth).convert('RGB')\n",
146+
" I_resize, w, h, feat_w, feat_h = utils.resize_pil(I, patch_size)\n",
147+
"\n",
148+
" tensor = ToTensor(I_resize).unsqueeze(0).cuda()\n",
149+
" feat = backbone(tensor)[0]\n",
150+
"\n",
151+
" seed, bipartition, eigvec, eigvectors = ncut(feat, [feat_h, feat_w], [patch_size, patch_size], [h,w], tau)\n",
152+
" return bipartition, eigvec, eigvectors.reshape([feat_h, feat_w, 2]).astype(float)\n",
153+
"\n",
154+
"def mask_color_compose(org, mask, mask_color = [173, 216, 230]) :\n",
155+
"\n",
156+
" mask_fg = mask > 0.5\n",
157+
" rgb = np.copy(org)\n",
158+
" rgb[mask_fg] = (rgb[mask_fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)\n",
159+
"\n",
160+
" return Image.fromarray(rgb)"
161+
]
162+
},
163+
{
164+
"cell_type": "code",
165+
"execution_count": 4,
166+
"metadata": {},
167+
"outputs": [
168+
{
169+
"name": "stdout",
170+
"output_type": "stream",
171+
"text": [
172+
"Namespace(out_dir='./output', vit_arch='small', vit_feat='k', patch_size=16, tau=0.2, sigma_spatial=16, sigma_luma=16, sigma_chroma=8, dataset=None, nb_vis=100, img_path='D:/deeplearning_Sanaz/TokenCut/TokenCut/examples/mydata/ab_wheel', save_feat_dir='../image')\n"
173+
]
174+
}
175+
],
176+
"source": [
177+
"parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
178+
"\n",
179+
"## input / output dir\n",
180+
"parser.add_argument('--out-dir', type=str, default = './output', help='output directory')\n",
181+
"\n",
182+
"parser.add_argument('--vit-arch', type=str, default='small', choices=['base', 'small'], help='which architecture')\n",
183+
"\n",
184+
"parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')\n",
185+
"\n",
186+
"parser.add_argument('--patch-size', type=int, default=16, choices=[16, 8], help='patch size')\n",
187+
"\n",
188+
"parser.add_argument('--tau', type=float, default=0.2, help='Tau for tresholding graph')\n",
189+
"\n",
190+
"parser.add_argument('--sigma-spatial', type=float, default=16, help='sigma spatial in the bilateral solver')\n",
191+
"\n",
192+
"parser.add_argument('--sigma-luma', type=float, default=16, help='sigma luma in the bilateral solver')\n",
193+
"\n",
194+
"parser.add_argument('--sigma-chroma', type=float, default=8, help='sigma chroma in the bilateral solver')\n",
195+
"\n",
196+
"\n",
197+
"parser.add_argument('--dataset', type=str, default=None, choices=['ECSSD', 'DUTS', 'DUT', None], help='which dataset?')\n",
198+
"\n",
199+
"parser.add_argument('--nb-vis', type=int, default=100, choices=[1, 200], help='nb of visualization')\n",
200+
"\n",
201+
"parser.add_argument('--img-path', type=str, default='fss-dataset/mydata/ab_wheel', help='single image visualization')\n",
202+
"parser.add_argument('--save_feat_dir',type=str, default= '../image')\n",
203+
"args = parser.parse_args(args=[])\n",
204+
"print (args)"
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": 5,
210+
"metadata": {},
211+
"outputs": [
212+
{
213+
"name": "stdout",
214+
"output_type": "stream",
215+
"text": [
216+
"Loading weight from /dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth\n"
217+
]
218+
},
219+
{
220+
"data": {
221+
"text/plain": [
222+
"ViTFeat(\n",
223+
" (model): VisionTransformer(\n",
224+
" (patch_embed): PatchEmbed(\n",
225+
" (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))\n",
226+
" )\n",
227+
" (pos_drop): Dropout(p=0.0, inplace=False)\n",
228+
" (blocks): ModuleList(\n",
229+
" (0-11): 12 x Block(\n",
230+
" (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
231+
" (attn): Attention(\n",
232+
" (qkv): Linear(in_features=384, out_features=1152, bias=True)\n",
233+
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
234+
" (proj): Linear(in_features=384, out_features=384, bias=True)\n",
235+
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
236+
" )\n",
237+
" (drop_path): Identity()\n",
238+
" (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
239+
" (mlp): Mlp(\n",
240+
" (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
241+
" (act): GELU(approximate='none')\n",
242+
" (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
243+
" (drop): Dropout(p=0.0, inplace=False)\n",
244+
" )\n",
245+
" )\n",
246+
" )\n",
247+
" (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
248+
" (head): Identity()\n",
249+
" )\n",
250+
")"
251+
]
252+
},
253+
"execution_count": 5,
254+
"metadata": {},
255+
"output_type": "execute_result"
256+
}
257+
],
258+
"source": [
259+
"## Define the network for feature extraction\n",
260+
"if args.vit_arch == 'base' and args.patch_size == 16:\n",
261+
" url = \"/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\"\n",
262+
" feat_dim = 768\n",
263+
"elif args.vit_arch == 'base' and args.patch_size == 8:\n",
264+
" url = \"/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth\"\n",
265+
" feat_dim = 768\n",
266+
"elif args.vit_arch == 'small' and args.patch_size == 16:\n",
267+
" url = \"/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth\"\n",
268+
" feat_dim = 384\n",
269+
"elif args.vit_arch == 'base' and args.patch_size == 8:\n",
270+
" url = \"/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth\"\n",
271+
"\n",
272+
"backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)\n",
273+
"msg = 'Load {} pre-trained feature...'.format(args.vit_arch)\n",
274+
"# print (msg)\n",
275+
"backbone.eval()\n",
276+
"backbone.cuda()"
277+
]
278+
},
279+
{
280+
"cell_type": "code",
281+
"execution_count": 6,
282+
"metadata": {},
283+
"outputs": [],
284+
"source": [
285+
"\n",
286+
"from skimage.io import imread\n",
287+
"import cv2\n",
288+
"from PIL import Image\n",
289+
"pattern = args.img_path + \"/**/*.jpg\"\n",
290+
"# Get a list of file paths that match the pattern\n",
291+
"image_list = glob.glob(pattern, recursive=True)\n",
292+
"\n",
293+
"# Iterate over the file paths and load the images using PIL\n",
294+
"for im_path in image_list:\n",
295+
" folder_path,im_pth = os.path.split(im_path)\n",
296+
"\n",
297+
" \n",
298+
" if im_pth.endswith('.jpg'):\n",
299+
" img = Image.open(im_path)\n",
300+
" original_image = imread(im_path)\n",
301+
" \n",
302+
"\n",
303+
" im_name =os.path.basename(im_pth)\n",
304+
" im_name = im_name.split('.')[0]\n",
305+
"\n",
306+
" \n",
307+
" bipartition, eigvec, eigvectors = get_tokencut_binary_map(im_path, backbone, args.patch_size, args.tau)\n",
308+
" bipartition = bipartition*255\n",
309+
" im_jpg = Image.fromarray( bipartition)\n",
310+
" binary_mask = im_jpg.convert('RGB')\n",
311+
" binary_mask.save(os.path.join(folder_path,im_name+'_mask'+'.png'))\n",
312+
" binary_mask = np.array(binary_mask)\n",
313+
" binary_mask = binary_mask.astype(np.float32) / 255\n",
314+
" \n",
315+
" "
316+
]
317+
},
318+
{
319+
"cell_type": "code",
320+
"execution_count": null,
321+
"metadata": {},
322+
"outputs": [],
323+
"source": []
324+
}
325+
],
326+
"metadata": {
327+
"kernelspec": {
328+
"display_name": "base",
329+
"language": "python",
330+
"name": "python3"
331+
},
332+
"language_info": {
333+
"codemirror_mode": {
334+
"name": "ipython",
335+
"version": 3
336+
},
337+
"file_extension": ".py",
338+
"mimetype": "text/x-python",
339+
"name": "python",
340+
"nbconvert_exporter": "python",
341+
"pygments_lexer": "ipython3",
342+
"version": "3.10.9"
343+
},
344+
"orig_nbformat": 4
345+
},
346+
"nbformat": 4,
347+
"nbformat_minor": 2
348+
}
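
The heart of ncut in the notebook above is a generalized eigenproblem on a patch-affinity graph. A stripped-down sketch of that step on random features (sizes are illustrative; in the notebook feats comes from the DINO backbone and tau defaults to 0.2):

import numpy as np
import torch
import torch.nn.functional as F
from scipy.linalg import eigh

feats = F.normalize(torch.randn(384, 14 * 14), p=2, dim=0)     # (feat_dim, n_patches), as produced by ViTFeat
A = (feats.transpose(0, 1) @ feats).numpy()                     # patch-to-patch similarity
A = np.where((A > 0.2).astype(float) == 0, 1e-5, A > 0.2)       # binarize with tau; eps keeps the graph connected
D = np.diag(A.sum(axis=1))
_, eigenvectors = eigh(D - A, D, subset_by_index=[1, 2])        # second/third smallest generalized eigenvectors
bipartition = eigenvectors[:, 0] > eigenvectors[:, 0].mean()    # foreground/background split over the 196 patches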

‎create_mask/dino.py

+361
@@ -0,0 +1,361 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copied from Dino repo. https://github.com/facebookresearch/dino
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
import warnings  # used by _no_grad_trunc_normal_ below
from functools import partial

import torch
import torch.nn as nn


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


class ViTFeat(nn.Module):
    """ Vision Transformer """
    def __init__(self, pretrained_pth, feat_dim, vit_arch='base', vit_feat='k', patch_size=16):
        super().__init__()
        if vit_arch == 'base':
            self.model = vit_base(patch_size=patch_size, num_classes=0)
        else:
            self.model = vit_small(patch_size=patch_size, num_classes=0)

        self.feat_dim = feat_dim
        self.vit_feat = vit_feat
        self.patch_size = patch_size

        # state_dict = torch.load(pretrained_pth, map_location="cpu")
        state_dict = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com" + pretrained_pth)
        self.model.load_state_dict(state_dict, strict=True)
        print('Loading weight from {}'.format(pretrained_pth))

    def forward(self, img):
        feat_out = {}

        def hook_fn_forward_qkv(module, input, output):
            feat_out["qkv"] = output

        self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)

        # Forward pass in the model
        with torch.no_grad():
            h, w = img.shape[2], img.shape[3]
            feat_h, feat_w = h // self.patch_size, w // self.patch_size
            attentions = self.model.get_last_selfattention(img)
            bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2]
            qkv = (
                feat_out["qkv"]
                .reshape(bs, nb_token, 3, nb_head, -1)
                .permute(2, 0, 3, 1, 4)
            )
            q, k, v = qkv[0], qkv[1], qkv[2]

            k = k.transpose(1, 2).reshape(bs, nb_token, -1)
            q = q.transpose(1, 2).reshape(bs, nb_token, -1)
            v = v.transpose(1, 2).reshape(bs, nb_token, -1)

            # Modality selection
            if self.vit_feat == "k":
                feats = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
            elif self.vit_feat == "q":
                feats = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
            elif self.vit_feat == "v":
                feats = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
            elif self.vit_feat == "kqv":
                k = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
                q = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
                v = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w)
                feats = torch.cat([k, q, v], dim=1)
            return feats


if __name__ == "__main__":
    vit_arch = 'base'
    vit_feat = 'k'

    # ViTFeat expects the checkpoint path and feature dim first; the base/16 weights path defined above is used here
    model = ViTFeat("/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 768, vit_arch, vit_feat)
    img = torch.cuda.FloatTensor(4, 3, 224, 224)
    model.cuda()
    # Forward pass in the model
    feat = model(img)
    print(feat.shape)
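
A shape-level sketch of what ViTFeat returns (assumptions: create_mask/ is on sys.path, a GPU is available, and the DINO checkpoint is fetched from dl.fbaipublicfiles.com on first use):

import torch
import dino

backbone = dino.ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
                        384, 'small', 'k', patch_size=16)
backbone.eval().cuda()
img = torch.randn(1, 3, 224, 224).cuda()
feats = backbone(img)
print(feats.shape)    # torch.Size([1, 384, 196]): one 'k' feature vector per 16x16 patch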

‎create_mask/utils.py

+9
@@ -0,0 +1,9 @@
import PIL.Image as Image


def resize_pil(I, patch_size=16):
    w, h = I.size

    new_w, new_h = int(round(w / patch_size)) * patch_size, int(round(h / patch_size)) * patch_size
    feat_w, feat_h = new_w // patch_size, new_h // patch_size

    return I.resize((new_w, new_h), resample=Image.LANCZOS), w, h, feat_w, feat_h
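
For example, with the default patch_size of 16 a 500x375 image is snapped to the nearest multiple of 16 on each side (dummy image below; any PIL image works):

from PIL import Image
from utils import resize_pil      # with create_mask/ on sys.path

img = Image.new('RGB', (500, 375))
resized, w, h, feat_w, feat_h = resize_pil(img, patch_size=16)
print(resized.size, (w, h), (feat_w, feat_h))    # (496, 368) (500, 375) (31, 23)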

‎data/dataset.py

+34
@@ -0,0 +1,34 @@
r""" Dataloader builder for few-shot semantic segmentation dataset """
from torchvision import transforms
from torch.utils.data import DataLoader
from data.fss import DatasetFSS


class FSSDataset:

    @classmethod
    def initialize(cls, img_size, datapath, use_original_imgsize):

        cls.datasets = {
            'fss': DatasetFSS,
        }

        # FSS
        cls.img_mean = [0.485, 0.456, 0.406]
        cls.img_std = [0.229, 0.224, 0.225]

        cls.datapath = datapath
        cls.use_original_imgsize = use_original_imgsize

        cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
                                            transforms.ToTensor()])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        shuffle = split == 'trn'
        nworker = nworker if split == 'trn' else 0

        dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform, split=split, shot=shot, use_original_imgsize=cls.use_original_imgsize)
        dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker)

        return dataloader
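
A sketch of the intended entry point (img_size and the data folder are assumptions; FSSDataset only wires DatasetFSS into a DataLoader):

from data.dataset import FSSDataset

FSSDataset.initialize(img_size=400, datapath='./Datasets', use_original_imgsize=False)   # ./Datasets must contain FSS-1000/
loader_trn = FSSDataset.build_dataloader('fss', bsz=8, nworker=4, fold=0, split='trn', shot=1)
loader_test = FSSDataset.build_dataloader('fss', bsz=1, nworker=0, fold=0, split='test', shot=1)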

‎data/fss.py

+128
@@ -0,0 +1,128 @@
r""" FSS-1000 few-shot semantic segmentation dataset """
import os
import glob

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import PIL.Image as Image
import numpy as np


class DatasetFSS(Dataset):
    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
        self.split = split
        self.benchmark = 'fss'
        self.shot = shot

        self.base_path = os.path.join(datapath, 'FSS-1000')

        # Given predefined test split, load randomly generated training/val splits:
        # (reference regarding trn/val/test splits: https://github.com/HKUSTCV/FSS-1000/issues/7))
        with open('./data/splits/fss/%s.txt' % split, 'r') as f:
            self.categories = f.read().split('\n')[:-1]
        self.categories = sorted(self.categories)

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()

        self.transform = transform

    def __len__(self):
        return len(self.img_metadata)

    def __getitem__(self, idx):
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_mask, support_imgs, support_masks = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()

        support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

        support_masks_tmp = []
        for smask in support_masks:
            smask = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks_tmp.append(smask)
        support_masks = torch.stack(support_masks_tmp)

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,

                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,

                 'class_id': torch.tensor(class_sample)}

        return batch

    def load_frame(self, query_name, support_names):
        query_img = Image.open(query_name).convert('RGB')
        support_imgs = [Image.open(name).convert('RGB') for name in support_names]

        query_id = query_name.split('/')[-1].split('.')[0]
        query_name = os.path.join(os.path.dirname(query_name), query_id) + '.png'
        support_ids = [name.split('/')[-1].split('.')[0] for name in support_names]
        if self.split == "test":
            support_names = [os.path.join(os.path.dirname(name), sid) + '_mask.png' for name, sid in zip(support_names, support_ids)]
        else:
            support_names = [os.path.join(os.path.dirname(name), sid) + '.png' for name, sid in zip(support_names, support_ids)]

        query_mask = self.read_mask(query_name)
        support_masks = [self.read_mask(name) for name in support_names]

        return query_img, query_mask, support_imgs, support_masks

    def read_mask(self, img_name):
        mask = torch.tensor(np.array(Image.open(img_name).convert('L')))
        mask[mask < 128] = 0
        mask[mask >= 128] = 1
        return mask

    def sample_episode(self, idx):
        query_name = self.img_metadata[idx]

        class_sample = self.categories.index(query_name.split('/')[-2])
        if self.split == 'val':
            class_sample += 520
        elif self.split == 'test':
            class_sample += 760

        support_names = []
        while True:  # keep sampling support set if query == support
            support_name = np.random.choice(range(1, 11), 1, replace=False)[0]
            support_name = os.path.join(os.path.dirname(query_name), str(support_name)) + '.jpg'
            support_name = support_name.replace("\\", "/")
            if query_name != support_name: support_names.append(support_name)
            if len(support_names) == self.shot: break

        return query_name, support_names, class_sample

    def build_class_ids(self):
        if self.split == 'trn':
            class_ids = range(0, 520)
        elif self.split == 'val':
            class_ids = range(520, 760)
        elif self.split == 'test':
            class_ids = range(760, 1000)
        return class_ids

    def build_img_metadata(self):
        img_metadata = []
        for cat in self.categories:
            img_paths = sorted([path for path in glob.glob('%s/*' % os.path.join(self.base_path, cat))])
            ## Correct address for windows
            temp = []
            for path in img_paths:
                temp.append(path.replace("\\", "/"))
            img_paths = temp
            for img_path in img_paths:
                if os.path.basename(img_path).split('.')[1] == 'jpg':
                    img_metadata.append(img_path)
        return img_metadata
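
Each item of DatasetFSS is one support/query episode; a sketch of what it contains, assuming a 400x400 Resize/ToTensor transform and shot=1 (the data folder is an assumption and must contain FSS-1000/):

from torchvision import transforms
from data.fss import DatasetFSS

transform = transforms.Compose([transforms.Resize((400, 400)), transforms.ToTensor()])
dataset = DatasetFSS('./Datasets', fold=0, transform=transform, split='test', shot=1,
                     use_original_imgsize=False)
episode = dataset[0]
print(episode['query_img'].shape,      # torch.Size([3, 400, 400])
      episode['support_imgs'].shape,   # torch.Size([1, 3, 400, 400])
      episode['class_id'])             # 760 + index of the episode's category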

‎data/splits/fss/test.txt

+240
@@ -0,0 +1,240 @@
1+
bus
2+
hotel_slipper
3+
burj_al
4+
reflex_camera
5+
abe's_flyingfish
6+
oiltank_car
7+
doormat
8+
fish_eagle
9+
barber_shaver
10+
motorbike
11+
feather_clothes
12+
wandering_albatross
13+
rice_cooker
14+
delta_wing
15+
fish
16+
nintendo_switch
17+
bustard
18+
diver
19+
minicooper
20+
cathedrale_paris
21+
big_ben
22+
combination_lock
23+
villa_savoye
24+
american_alligator
25+
gym_ball
26+
andean_condor
27+
leggings
28+
pyramid_cube
29+
jet_aircraft
30+
meatloaf
31+
reel
32+
swan
33+
osprey
34+
crt_screen
35+
microscope
36+
rubber_eraser
37+
arrow
38+
monkey
39+
mitten
40+
spiderman
41+
parthenon
42+
bat
43+
chess_king
44+
sulphur_butterfly
45+
quail_egg
46+
oriole
47+
iron_man
48+
wooden_boat
49+
anise
50+
steering_wheel
51+
groenendael
52+
dwarf_beans
53+
pteropus
54+
chalk_brush
55+
bloodhound
56+
moon
57+
english_foxhound
58+
boxing_gloves
59+
peregine_falcon
60+
pyraminx
61+
cicada
62+
screw
63+
shower_curtain
64+
tredmill
65+
bulb
66+
bell_pepper
67+
lemur_catta
68+
doughnut
69+
twin_tower
70+
astronaut
71+
nintendo_3ds
72+
fennel_bulb
73+
indri
74+
captain_america_shield
75+
kunai
76+
broom
77+
iphone
78+
earphone1
79+
flying_squirrel
80+
onion
81+
vinyl
82+
sydney_opera_house
83+
oyster
84+
harmonica
85+
egg
86+
breast_pump
87+
guitar
88+
potato_chips
89+
tunnel
90+
cuckoo
91+
rubick_cube
92+
plastic_bag
93+
phonograph
94+
net_surface_shoes
95+
goldfinch
96+
ipad
97+
mite_predator
98+
coffee_mug
99+
golden_plover
100+
f1_racing
101+
lapwing
102+
nintendo_gba
103+
pizza
104+
rally_car
105+
drilling_platform
106+
cd
107+
fly
108+
magpie_bird
109+
leaf_fan
110+
little_blue_heron
111+
carriage
112+
moist_proof_pad
113+
flying_snakes
114+
dart_target
115+
warehouse_tray
116+
nintendo_wiiu
117+
chiffon_cake
118+
bath_ball
119+
manatee
120+
cloud
121+
marimba
122+
eagle
123+
ruler
124+
soymilk_machine
125+
sled
126+
seagull
127+
glider_flyingfish
128+
doublebus
129+
transport_helicopter
130+
window_screen
131+
truss_bridge
132+
wasp
133+
snowman
134+
poached_egg
135+
strawberry
136+
spinach
137+
earphone2
138+
downy_pitch
139+
taj_mahal
140+
rocking_chair
141+
cablestayed_bridge
142+
sealion
143+
banana_boat
144+
pheasant
145+
stone_lion
146+
electronic_stove
147+
fox
148+
iguana
149+
rugby_ball
150+
hang_glider
151+
water_buffalo
152+
lotus
153+
paper_plane
154+
missile
155+
flamingo
156+
american_chamelon
157+
kart
158+
chinese_knot
159+
cabbage_butterfly
160+
key
161+
church
162+
tiltrotor
163+
helicopter
164+
french_fries
165+
water_heater
166+
snow_leopard
167+
goblet
168+
fan
169+
snowplow
170+
leafhopper
171+
pspgo
172+
black_bear
173+
quail
174+
condor
175+
chandelier
176+
hair_razor
177+
white_wolf
178+
toaster
179+
pidan
180+
pyramid
181+
chicken_leg
182+
letter_opener
183+
apple_icon
184+
porcupine
185+
chicken
186+
stingray
187+
warplane
188+
windmill
189+
bamboo_slip
190+
wig
191+
flying_geckos
192+
stonechat
193+
haddock
194+
australian_terrier
195+
hover_board
196+
siamang
197+
canton_tower
198+
santa_sledge
199+
arch_bridge
200+
curlew
201+
sushi
202+
beet_root
203+
accordion
204+
leaf_egg
205+
stealth_aircraft
206+
stork
207+
bucket
208+
hawk
209+
chess_queen
210+
ocarina
211+
knife
212+
whippet
213+
cantilever_bridge
214+
may_bug
215+
wagtail
216+
leather_shoes
217+
wheelchair
218+
shumai
219+
speedboat
220+
vacuum_cup
221+
chess_knight
222+
pumpkin_pie
223+
wooden_spoon
224+
bamboo_dragonfly
225+
ganeva_chair
226+
soap
227+
clearwing_flyingfish
228+
pencil_sharpener1
229+
cricket
230+
photocopier
231+
nintendo_sp
232+
samarra_mosque
233+
clam
234+
charge_battery
235+
flying_frog
236+
ferrari911
237+
polo_shirt
238+
echidna
239+
coin
240+
tower_pisa

‎data/splits/fss/trn.txt

+520
Large diffs are not rendered by default.

‎data/splits/fss/val.txt

+240
@@ -0,0 +1,240 @@
1+
handcuff
2+
mortar
3+
matchstick
4+
wine_bottle
5+
dowitcher
6+
triumphal_arch
7+
gyromitra
8+
hatchet
9+
airliner
10+
broccoli
11+
olive
12+
pubg_lvl3backpack
13+
calculator
14+
toucan
15+
shovel
16+
sewing_machine
17+
icecream
18+
woodpecker
19+
pig
20+
relay_stick
21+
mcdonald_sign
22+
cpu
23+
peanut
24+
pumpkin
25+
sturgeon
26+
hammer
27+
hami_melon
28+
squirrel_monkey
29+
shuriken
30+
power_drill
31+
pingpong_ball
32+
crocodile
33+
carambola
34+
monarch_butterfly
35+
drum
36+
water_tower
37+
panda
38+
toilet_brush
39+
pay_phone
40+
yonex_icon
41+
cricketball
42+
revolver
43+
chimpanzee
44+
crab
45+
corn
46+
baseball
47+
rabbit
48+
croquet_ball
49+
artichoke
50+
abacus
51+
harp
52+
bell
53+
gas_tank
54+
scissors
55+
vase
56+
upright_piano
57+
typewriter
58+
bittern
59+
impala
60+
tray
61+
fire_hydrant
62+
beer_bottle
63+
sock
64+
soup_bowl
65+
spider
66+
cherry
67+
macaw
68+
toilet_seat
69+
fire_balloon
70+
french_ball
71+
fox_squirrel
72+
volleyball
73+
cornmeal
74+
folding_chair
75+
pubg_airdrop
76+
beagle
77+
skateboard
78+
narcissus
79+
whiptail
80+
cup
81+
arabian_camel
82+
badger
83+
stopwatch
84+
ab_wheel
85+
ox
86+
lettuce
87+
monocycle
88+
redshank
89+
vulture
90+
whistle
91+
smoothing_iron
92+
mashed_potato
93+
conveyor
94+
yoga_pad
95+
tow_truck
96+
siamese_cat
97+
cigar
98+
white_stork
99+
sniper_rifle
100+
stretcher
101+
tulip
102+
handkerchief
103+
basset
104+
iceberg
105+
gibbon
106+
lacewing
107+
thrush
108+
cheetah
109+
bighorn_sheep
110+
espresso_maker
111+
pretzel
112+
english_setter
113+
sandbar
114+
cheese
115+
daisy
116+
arctic_fox
117+
briard
118+
colubus
119+
balance_beam
120+
coffeepot
121+
soap_dispenser
122+
yawl
123+
consomme
124+
parking_meter
125+
cactus
126+
turnstile
127+
taro
128+
fire_screen
129+
digital_clock
130+
rose
131+
pomegranate
132+
bee_eater
133+
schooner
134+
ski_mask
135+
jay_bird
136+
plaice
137+
red_fox
138+
syringe
139+
camomile
140+
pickelhaube
141+
blenheim_spaniel
142+
pear
143+
parachute
144+
common_newt
145+
bowtie
146+
cigarette
147+
oscilloscope
148+
laptop
149+
african_crocodile
150+
apron
151+
coconut
152+
sandal
153+
kwanyin
154+
lion
155+
eel
156+
balloon
157+
crepe
158+
armadillo
159+
kazoo
160+
lemon
161+
spider_monkey
162+
tape_player
163+
ipod
164+
bee
165+
sea_cucumber
166+
suitcase
167+
television
168+
pillow
169+
banjo
170+
rock_snake
171+
partridge
172+
platypus
173+
lycaenid_butterfly
174+
pinecone
175+
conversion_plug
176+
wolf
177+
frying_pan
178+
timber_wolf
179+
bluetick
180+
crayon
181+
giant_schnauzer
182+
orang
183+
scarerow
184+
kobe_logo
185+
loguat
186+
saxophone
187+
ceiling_fan
188+
cardoon
189+
equestrian_helmet
190+
louvre_pyramid
191+
hotdog
192+
ironing_board
193+
razor
194+
nagoya_castle
195+
loggerhead_turtle
196+
lipstick
197+
cradle
198+
strongbox
199+
raven
200+
kit_fox
201+
albatross
202+
flat-coated_retriever
203+
beer_glass
204+
ice_lolly
205+
sungnyemun
206+
totem_pole
207+
vacuum
208+
bolete
209+
mango
210+
ginger
211+
weasel
212+
cabbage
213+
refrigerator
214+
school_bus
215+
hippo
216+
tiger_cat
217+
saltshaker
218+
piano_keyboard
219+
windsor_tie
220+
sea_urchin
221+
microsd
222+
barbell
223+
swim_ring
224+
bulbul_bird
225+
water_ouzel
226+
ac_ground
227+
sweatshirt
228+
umbrella
229+
hair_drier
230+
hammerhead_shark
231+
tomato
232+
projector
233+
cushion
234+
dishwasher
235+
three-toed_sloth
236+
tiger_shark
237+
har_gow
238+
baby
239+
thor's_hammer
240+
nike_logo

‎model/decoder.py

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class conv_block(nn.Module):
7+
def __init__(self,ch_in,ch_out):
8+
super(conv_block,self).__init__()
9+
self.conv = nn.Sequential(
10+
nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
11+
nn.BatchNorm2d(ch_out),
12+
nn.ReLU(inplace=True),
13+
nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
14+
nn.BatchNorm2d(ch_out),
15+
nn.ReLU(inplace=True)
16+
)
17+
18+
def forward(self,x):
19+
#print(x.shape)
20+
x = self.conv(x)
21+
return x
22+
23+
class up_conv(nn.Module):
24+
def __init__(self,ch_in,ch_out,kernel_size=3, stride=1, padding=1, groups=1):
25+
super(up_conv,self).__init__()
26+
self.up = nn.Sequential(
27+
nn.Upsample(scale_factor=2),
28+
nn.Conv2d(ch_in,ch_out,kernel_size=kernel_size,stride=stride,padding=padding,bias=True),
29+
nn.BatchNorm2d(ch_out),
30+
nn.ReLU(inplace=True)
31+
)
32+
33+
def forward(self,x):
34+
x = self.up(x)
35+
return x
36+
37+
38+
class Attention_block(nn.Module):
39+
def __init__(self,F_g,F_l,F_int):
40+
super(Attention_block,self).__init__()
41+
self.dim = F_l
42+
self.W_g = nn.Sequential(
43+
nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
44+
nn.BatchNorm2d(F_int)
45+
)
46+
47+
self.W_x = nn.Sequential(
48+
nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
49+
nn.BatchNorm2d(F_int)
50+
)
51+
52+
self.psi = nn.Sequential(
53+
nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
54+
nn.BatchNorm2d(1),
55+
nn.Sigmoid()
56+
)
57+
58+
self.relu = nn.ReLU(inplace=True)
59+
# self.conv1_1 = nn.Conv2d(2*F_l, F_l, kernel_size=1,stride=1,padding=0,bias=True)
60+
self.conv3d = nn.Conv3d(2, 1, 3, padding=1)
61+
62+
def forward(self,g,x, pad = (0, 1, 0, 1)):
63+
x = torch.concat([x[:,:self.dim].unsqueeze(dim = 1), x[:,self.dim:].unsqueeze(dim = 1)], dim=1)
64+
x = self.conv3d(x).squeeze(dim = 1)
65+
x = F.pad(x, pad, mode='replicate')
66+
g1 = self.W_g(g)
67+
x1 = self.W_x(x)
68+
psi = self.relu(g1+x1)
69+
psi = self.psi(psi)
70+
return x*psi
71+
72+
class ChannelAttention(nn.Module):
73+
def __init__(self, in_planes, ratio=16):
74+
super(ChannelAttention, self).__init__()
75+
self.in_planes = in_planes
76+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
77+
self.max_pool = nn.AdaptiveMaxPool2d(1)
78+
79+
self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
80+
self.relu1 = nn.ReLU()
81+
self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
82+
83+
self.sigmoid = nn.Sigmoid()
84+
85+
def forward(self, x):
86+
avg_pool_out = self.avg_pool(x)
87+
avg_out = self.fc2(self.relu1(self.fc1(avg_pool_out)))
88+
#print(x.shape)
89+
max_pool_out= self.max_pool(x) #torch.topk(x,3, dim=1).values
90+
91+
max_out = self.fc2(self.relu1(self.fc1(max_pool_out)))
92+
out = avg_out + max_out
93+
return self.sigmoid(out)
94+
95+
class SpatialAttention(nn.Module):
96+
def __init__(self, kernel_size=7):
97+
super(SpatialAttention, self).__init__()
98+
99+
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
100+
padding = 3 if kernel_size == 7 else 1
101+
102+
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
103+
self.sigmoid = nn.Sigmoid()
104+
105+
def forward(self, x):
106+
avg_out = torch.mean(x, dim=1, keepdim=True)
107+
max_out, _ = torch.max(x, dim=1, keepdim=True)
108+
x = torch.cat([avg_out, max_out], dim=1)
109+
x = self.conv1(x)
110+
return self.sigmoid(x)
111+
112+
class AdaptiveLKA(nn.Module):
113+
def __init__(self, dim, use3d = False):
114+
super().__init__()
115+
self.dim = dim
116+
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
117+
self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
118+
self.conv1 = nn.Conv2d(dim, dim, 1)
119+
self.use3d = use3d
120+
if use3d:
121+
self.conv3d = nn.Conv3d(2, 1, 3, padding=1)
122+
123+
def forward(self, x):
124+
if self.use3d:
125+
x = torch.concat([x[:,:self.dim].unsqueeze(dim = 1), x[:,self.dim:].unsqueeze(dim = 1)], dim=1)
126+
x = self.conv3d(x).squeeze(dim = 1)
127+
u = x.clone()
128+
attn = self.conv0(x)
129+
attn = self.conv_spatial(attn)
130+
attn = self.conv1(attn)
131+
132+
return u * attn
133+
134+
class LKA_decoder(nn.Module):
135+
def __init__(self, channels=[2048,1024,512,256]):
136+
super(LKA_decoder,self).__init__()
137+
self.channels = channels
138+
self.Conv_1x1 = nn.Conv2d(2*channels[0],2*channels[0],kernel_size=1,stride=1,padding=0)
139+
self.ConvBlock4 = conv_block(ch_in=channels[0], ch_out=channels[0])
140+
141+
self.Up3 = up_conv(ch_in=channels[0],ch_out=channels[1])
142+
self.AG3 = Attention_block(F_g=channels[1],F_l=channels[1],F_int=channels[2])
143+
self.ConvBlock3 = conv_block(ch_in=channels[1], ch_out=channels[1])
144+
145+
self.Up2 = up_conv(ch_in=channels[1],ch_out=channels[2])
146+
self.AG2 = Attention_block(F_g=channels[2],F_l=channels[2],F_int=channels[3])
147+
self.ConvBlock2 = conv_block(ch_in=channels[2], ch_out=channels[2])
148+
149+
self.Up1 = up_conv(ch_in=channels[2],ch_out=channels[3])
150+
self.AG1 = Attention_block(F_g=channels[3],F_l=channels[3],F_int=int(channels[3]/2))
151+
self.ConvBlock1 = conv_block(ch_in=channels[3], ch_out=channels[3])
152+
153+
self.CA4 = ChannelAttention(channels[0])
154+
self.CA3 = ChannelAttention(channels[1])
155+
self.CA2 = ChannelAttention(channels[2])
156+
self.CA1 = ChannelAttention(channels[3])
157+
158+
self.ALKA4 = AdaptiveLKA(dim=channels[0], use3d = True)
159+
self.ALKA3 = AdaptiveLKA(dim=channels[1])
160+
self.ALKA2 = AdaptiveLKA(dim=channels[2])
161+
self.ALKA1 = AdaptiveLKA(dim=channels[3])
162+
163+
164+
self.SA = SpatialAttention()
165+
self.Upf = up_conv(ch_in=channels[3],ch_out=32)
166+
self.decoderf = nn.Sequential(nn.Conv2d(32, 16, (3, 3), padding=(1, 1), bias=True),
167+
nn.ReLU(),
168+
nn.Conv2d(16, 2, (3, 3), padding=(1, 1), bias=True))
169+
def forward(self,x, skips):
170+
d4 = self.Conv_1x1(x)
171+
# CAM4
172+
d4 = self.ALKA4(d4)
173+
174+
d4 = self.ConvBlock4(d4)
175+
176+
# upconv3
177+
d3 = self.Up3(d4)
178+
179+
# AG3
180+
x3 = self.AG3(g=d3,x=skips[0])
181+
# aggregate 3
182+
d3 = d3 + x3
183+
184+
185+
# CAM3
186+
d3 = self.ALKA3(d3)
187+
d3 = self.ConvBlock3(d3)
188+
189+
# upconv2
190+
d2 = self.Up2(d3)
191+
192+
# AG2
193+
x2 = self.AG2(g=d2,x=skips[1], pad = (0, 2, 0, 2))
194+
195+
# aggregate 2
196+
d2 = d2 + x2
197+
198+
# CAM2
199+
d2 = self.ALKA2(d2)
200+
201+
d2 = self.ConvBlock2(d2)
202+
203+
# upconv1
204+
d1 = self.Up1(d2)
205+
206+
207+
# AG1
208+
x1 = self.AG1(g=d1,x=skips[2], pad = (0, 4, 0, 4))
209+
210+
# aggregate 1
211+
d1 = d1 + x1
212+
213+
# CAM1
214+
215+
d1 = self.ALKA1(d1)
216+
d1 = self.ConvBlock1(d1)
217+
218+
d1 = self.decoderf(self.Upf(d1))
219+
220+
return d1
221+
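For orientation, a shape-only smoke test of the decoder could look like the sketch below. It is not part of this commit; it assumes the 400x400 setting used elsewhere in the repo, with up_conv doubling spatial resolution and conv_block preserving it (which the pad arguments above imply). Under those assumptions the decoder returns 2-channel logits at 208x208, later interpolated to image size in mymodel.py.

# Hypothetical smoke test for LKA_decoder (shapes follow the assumptions above).
import torch
from model.decoder import LKA_decoder

decoder = LKA_decoder(channels=[2048, 1024, 512, 256]).eval()
x = torch.randn(2, 4096, 13, 13)             # concatenated support+query layer-4 features
skips = [torch.randn(2, 2048, 25, 25),       # layer-3 features (support+query)
         torch.randn(2, 1024, 50, 50),       # layer-2
         torch.randn(2, 512, 100, 100)]      # layer-1
with torch.no_grad():
    logits = decoder(x, skips)
print(logits.shape)                          # expected: torch.Size([2, 2, 208, 208])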

‎model/mymodel.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
2+
""" few shot network """
3+
from functools import reduce
4+
from operator import add
5+
from .decoder import LKA_decoder
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.functional as F
9+
from torchvision.models import resnet
10+
11+
class fewshotnet(nn.Module):
12+
def __init__(self, use_original_imgsize = False):
13+
super(fewshotnet, self).__init__()
14+
self.use_original_imgsize = use_original_imgsize
15+
self.backbone = resnet.resnet50(pretrained=True)
16+
self.feat_ids = [3,7,13,16]#list(range(2, 17))
17+
self.extract_feats = extract_feat_res
18+
nbottlenecks = [3, 4, 6, 3]
19+
self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
20+
self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
21+
self.stack_ids = torch.tensor(self.lids).bincount().__reversed__().cumsum(dim=0)[:3]
22+
self.backbone.eval()
23+
24+
self.cross_entropy_loss = nn.CrossEntropyLoss()
25+
self.decoder = LKA_decoder()
26+
27+
def forward(self, query_img, support_img, support_mask):
28+
with torch.no_grad():
29+
query_feats = self.extract_feats(query_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
30+
support_feats = self.extract_feats(support_img, self.backbone, self.feat_ids, self.bottleneck_ids, self.lids)
31+
support_feats = self.mask_feature(support_feats, support_mask.clone())
32+
33+
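# Pair each masked support feature with its query feature by channel-wise concatenation.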
for idx, (feat_s, feat_q) in enumerate(zip(support_feats, query_feats)):
34+
query_feats[idx] = torch.concatenate([feat_s, feat_q], dim =1)
35+
36+
logit_mask = self.decoder(query_feats[3], [query_feats[2], query_feats[1], query_feats[0]])
37+
38+
if not self.use_original_imgsize:
39+
logit_mask = F.interpolate(logit_mask, support_img.size()[2:], mode='bilinear', align_corners=True)
40+
41+
return logit_mask
42+
43+
def mask_feature(self, features, support_mask):
44+
for idx, feature in enumerate(features):
45+
mask = F.interpolate(support_mask.unsqueeze(1).float(), feature.size()[2:], mode='bilinear', align_corners=True)
46+
temp = features[idx] * mask
47+
features[idx] = temp
48+
49+
50+
return features
51+
52+
def predict_mask_nshot(self, batch, nshot, thresh = 0.5):
53+
logit_mask_agg = 0
54+
for s_idx in range(nshot):
55+
logit_mask = self(batch['query_img'], batch['support_imgs'][:, s_idx], batch['support_masks'][:, s_idx])
56+
57+
if self.use_original_imgsize:
58+
org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()])
59+
logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
60+
61+
logit_mask_agg += logit_mask.argmax(dim=1).clone()
62+
if nshot == 1: return logit_mask_agg
63+
64+
# Average & quantize predictions given threshold (=0.5)
65+
bsz = logit_mask_agg.size(0)
66+
max_vote = logit_mask_agg.view(bsz, -1).max(dim=1)[0]
67+
max_vote = torch.stack([max_vote, torch.ones_like(max_vote).long()])
68+
max_vote = max_vote.max(dim=0)[0].view(bsz, 1, 1)
69+
pred_mask = logit_mask_agg.float() / max_vote
70+
pred_mask[pred_mask < thresh] = 0
71+
pred_mask[pred_mask >= thresh] = 1
72+
73+
return pred_mask
74+
75+
def compute_objective(self, logit_mask, gt_mask):
76+
bsz = logit_mask.size(0)
77+
logit_mask = logit_mask.view(bsz, 2, -1)
78+
gt_mask = gt_mask.view(bsz, -1).long()
79+
80+
return self.cross_entropy_loss(logit_mask, gt_mask)
81+
82+
def train_mode(self):
83+
self.train()
84+
self.backbone.eval()
85+
86+
def extract_feat_res(img, backbone, feat_ids, bottleneck_ids, lids):
87+
r""" Extract intermediate features from ResNet"""
88+
feats = []
89+
90+
# Layer 0
91+
feat = backbone.conv1.forward(img)
92+
feat = backbone.bn1.forward(feat)
93+
feat = backbone.relu.forward(feat)
94+
feat = backbone.maxpool.forward(feat)
95+
96+
# Layer 1-4
97+
for hid, (bid, lid) in enumerate(zip(bottleneck_ids, lids)):
98+
res = feat
99+
feat = backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
100+
feat = backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
101+
feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
102+
feat = backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
103+
feat = backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
104+
feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
105+
feat = backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
106+
feat = backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
107+
108+
if bid == 0:
109+
res = backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
110+
feat += res
111+
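# With feat_ids = [3, 7, 13, 16], this keeps the pre-ReLU output of the last bottleneck of each stage.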
if hid + 1 in feat_ids:
112+
feats.append(feat.clone())
113+
feat = backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
114+
return feats
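A minimal single-episode forward pass, for reference; this sketch is not part of the commit and assumes 400x400 inputs as in train.py/test.py (it also downloads ImageNet ResNet-50 weights on first use).

# Hypothetical usage sketch: one 1-shot episode through fewshotnet.
import torch
from model.mymodel import fewshotnet

model = fewshotnet().eval()
query_img = torch.randn(1, 3, 400, 400)
support_img = torch.randn(1, 3, 400, 400)
support_mask = torch.randint(0, 2, (1, 400, 400)).float()
with torch.no_grad():
    logit_mask = model(query_img, support_img, support_mask)  # (1, 2, 400, 400)
    pred_mask = logit_mask.argmax(dim=1)                      # (1, 400, 400) binary prediction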

‎test.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
r""" Hypercorrelation Squeeze testing code """
2+
import argparse
3+
4+
import torch.nn.functional as F
5+
import torch.nn as nn
6+
import torch
7+
8+
from model.mymodel import fewshotnet
9+
from common.logger import Logger, AverageMeter
10+
from common.vis import Visualizer
11+
from common.evaluation import Evaluator
12+
from common import utils
13+
from data.dataset import FSSDataset
14+
15+
from torch.utils.data import Dataset, DataLoader
16+
17+
import numpy as np
18+
19+
20+
21+
def compute_miou(Es_mask, qmask):
22+
Es_mask, qmask = Es_mask.detach().cpu().numpy(), qmask.detach().cpu().numpy()
23+
ious = 0.0
24+
Es_mask = np.where(Es_mask> 0.5, 1. , 0.)
25+
for idx in range(Es_mask.shape[0]):
26+
notTrue = 1 - qmask[idx]
27+
union = np.sum(qmask[idx] + (notTrue * Es_mask[idx]))
28+
intersection = np.sum(qmask[idx] * Es_mask[idx])
29+
ious += (intersection / union)
30+
miou = (ious / Es_mask.shape[0])
31+
return miou
32+
33+
34+
35+
def test(model, dataloader, nshot):
36+
r""" Test HSNet """
37+
miou = 0.0
38+
# Freeze randomness during testing for reproducibility
39+
utils.fix_randseed(0)
40+
average_meter = AverageMeter(dataloader.dataset)
41+
for idx, batch in enumerate(dataloader):
42+
batch = utils.to_cuda(batch)
43+
pred_mask = model.module.predict_mask_nshot(batch, nshot=nshot)
44+
45+
miou += compute_miou(pred_mask, batch['query_mask'])
46+
# name = batch['query_name'][0][:-4]+'_pred.png'
47+
48+
# mask = pred_mask.detach().cpu().numpy()[0]
49+
50+
# from PIL import Image
51+
# from matplotlib import cm
52+
# from matplotlib import pyplot as plt
53+
54+
# plt.imsave(name, mask, cmap=cm.gray)
55+
56+
# assert pred_mask.size() == batch['query_mask'].size()
57+
58+
# # 2. Evaluate prediction
59+
# area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
60+
# average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
61+
# average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
62+
63+
64+
# # Visualize predictions
65+
# if Visualizer.visualize:
66+
# Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
67+
# batch['query_img'], batch['query_mask'],
68+
# pred_mask, batch['class_id'], idx,
69+
# area_inter[1].float() / area_union[1].float())
70+
71+
# # Write evaluation results
72+
# average_meter.write_result('Test', 0)
73+
# miou, fb_iou = average_meter.compute_iou()
74+
75+
return miou / (idx + 1) * 100.
76+
77+
78+
if __name__ == '__main__':
79+
80+
# Arguments parsing
81+
parser = argparse.ArgumentParser(description='Annotation-free few-shot segmentation PyTorch implementation')
82+
parser.add_argument('--datapath', type=str, default='D:/dataset/fewshot_data/')
83+
parser.add_argument('--benchmark', type=str, default='fss')
84+
parser.add_argument('--logpath', type=str, default='')
85+
parser.add_argument('--bsz', type=int, default=24)
86+
parser.add_argument('--nworker', type=int, default=0)
87+
parser.add_argument('--load', type=str, default='./logs/fss_weightsnew.pt')
88+
parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
89+
parser.add_argument('--nshot', type=int, default=1)
90+
parser.add_argument('--backbone', type=str, default='resnet50')
91+
parser.add_argument('--visualize', default=False, action='store_true')
92+
parser.add_argument('--use_original_imgsize', action='store_true')
93+
args = parser.parse_args()
94+
Logger.initialize(args, training=False)
95+
96+
# Model initialization
97+
model = fewshotnet()
98+
model.eval()
99+
Logger.log_params(model)
100+
101+
# Device setup
102+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
103+
Logger.info('# available GPUs: %d' % torch.cuda.device_count())
104+
model = nn.DataParallel(model)
105+
model.to(device)
106+
107+
# Load trained model
108+
if args.load == '': raise Exception('Pretrained model not specified.')
109+
model.load_state_dict(torch.load(args.load))
110+
print('Model created and weights loaded')
111+
112+
# Helper classes (for testing) initialization
113+
Evaluator.initialize()
114+
Visualizer.initialize(args.visualize)
115+
116+
# Dataset initialization
117+
FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
118+
dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
119+
120+
121+
# Evaluate the few-shot segmentation network
122+
with torch.no_grad():
123+
test_miou = test(model, dataloader_test, args.nshot)
124+
print(f'Test mIoU: {test_miou:.2f}')
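As a sanity check on the metric, compute_miou reduces to plain IoU for a single mask; a toy example (hypothetical, not part of the commit, run with compute_miou from this file in scope):

import torch
pred = torch.tensor([[[1., 1.], [0., 0.]]])  # predicted foreground
gt = torch.tensor([[[1., 0.], [0., 0.]]])    # ground truth
# one pixel of intersection, two of union -> IoU = 0.5
print(compute_miou(pred, gt))                # 0.5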

‎train.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
r""" Hypercorrelation Squeeze training (validation) code """
2+
import argparse
3+
4+
import torch.optim as optim
5+
import torch.nn as nn
6+
import torch
7+
8+
from model.mymodel import fewshotnet
9+
from common.logger import Logger, AverageMeter
10+
from common.evaluation import Evaluator
11+
from common import utils
12+
from data.dataset import FSSDataset
13+
14+
15+
def train(epoch, model, dataloader, optimizer, training):
16+
r""" Train HSNet """
17+
18+
# Force randomness during training / freeze randomness during testing
19+
utils.fix_randseed(None) if training else utils.fix_randseed(0)
20+
model.module.train_mode() if training else model.module.eval()
21+
average_meter = AverageMeter(dataloader.dataset)
22+
23+
for idx, batch in enumerate(dataloader):
24+
25+
# 1. Few-shot segmentation network forward pass
26+
batch = utils.to_cuda(batch)
27+
logit_mask = model(batch['query_img'], batch['support_imgs'].squeeze(1), batch['support_masks'].squeeze(1))
28+
pred_mask = logit_mask.argmax(dim=1)
29+
30+
# 2. Compute loss & update model parameters
31+
loss = model.module.compute_objective(logit_mask, batch['query_mask'])
32+
if training:
33+
optimizer.zero_grad()
34+
loss.backward()
35+
optimizer.step()
36+
37+
# 3. Evaluate prediction
38+
area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
39+
average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
40+
average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
41+
42+
# Write evaluation results
43+
average_meter.write_result('Training' if training else 'Validation', epoch)
44+
avg_loss = utils.mean(average_meter.loss_buf)
45+
miou, fb_iou = average_meter.compute_iou()
46+
47+
return avg_loss, miou, fb_iou
48+
49+
if __name__ == '__main__':
50+
51+
# Arguments parsing
52+
parser = argparse.ArgumentParser(description='Annotation free few-shot segmentation Pytorch Implementation')
53+
parser.add_argument('--datapath', type=str, default='D:/dataset/fewshot_data/')
54+
parser.add_argument('--benchmark', type=str, default='fss')
55+
parser.add_argument('--logpath', type=str, default='')
56+
parser.add_argument('--bsz', type=int, default=20)
57+
parser.add_argument('--lr', type=float, default=1e-3)
58+
parser.add_argument('--niter', type=int, default=2000)
59+
parser.add_argument('--nworker', type=int, default=8)
60+
parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
61+
parser.add_argument('--backbone', type=str, default='resnet50')
62+
args = parser.parse_args()
63+
Logger.initialize(args, training=True)
64+
65+
# Model initialization
66+
model = fewshotnet()
67+
Logger.log_params(model)
68+
69+
# Device setup
70+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
71+
Logger.info('# available GPUs: %d' % torch.cuda.device_count())
72+
model = nn.DataParallel(model)
73+
model.to(device)
74+
75+
# Helper classes (for training) initialization
76+
optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}])
77+
Evaluator.initialize()
78+
79+
# Dataset initialization
80+
FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False)
81+
dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn')
82+
dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val')
83+
84+
# Train the few-shot segmentation network
85+
best_val_miou = float('-inf')
86+
best_val_loss = float('inf')
87+
for epoch in range(args.niter):
88+
89+
trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)
90+
with torch.no_grad():
91+
val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)
92+
93+
# Save the best model
94+
if val_miou > best_val_miou:
95+
best_val_miou = val_miou
96+
Logger.save_model_miou(model, epoch, val_miou)
97+
torch.save(model.state_dict(), './logs/fss_weights.pt')
98+
99+
100+
Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
101+
Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
102+
Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
103+
Logger.tbd_writer.flush()
104+
Logger.tbd_writer.close()
105+
Logger.info('==================== Finished Training ====================')
