
Commit 2be7b69

committed Oct 9, 2018
Added oneshot import package
1 parent d94c50e commit 2be7b69

1 file changed: +141 −0 lines changed

oneshot.py

@@ -0,0 +1,141 @@
import os
import keras
import cv2
from urllib.request import urlretrieve
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
# triplet_loss uses TensorFlow ops and init_model uses the Keras InceptionV3
# application, so both need explicit imports.
import tensorflow as tf
from keras.applications.inception_v3 import InceptionV3

def triplet_loss(y_true, y_pred, alpha=0.2):
    '''
    Returns the triplet loss for an (anchor, positive, negative) triple of encodings.
    '''
    anchor, positive, negative = y_pred[0], y_pred[1], y_pred[2]
    # Squared L2 distances between the anchor and the positive / negative encodings
    pos_dist = tf.reduce_sum(tf.square(anchor - positive))
    neg_dist = tf.reduce_sum(tf.square(anchor - negative))
    # Hinge the loss: only penalise triples where the negative is not at least
    # `alpha` farther from the anchor than the positive
    basic_loss = pos_dist - neg_dist + alpha
    loss = tf.maximum(basic_loss, 0.0)
    return loss

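# Worked example (hypothetical numbers, not from the dataset): with scalar
# encodings anchor=0.0, positive=0.1, negative=1.0 and alpha=0.2,
# pos_dist = 0.01, neg_dist = 1.0, basic_loss = 0.01 - 1.0 + 0.2 = -0.79,
# so loss = max(-0.79, 0) = 0: the triple is already separated by the margin.
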
def init_model(input_shape=(200, 200, 3)):
    '''
    Initializes the model and returns it.

    Arguments:
    input_shape: shape of the image input, including channels.
        Default is (200, 200, 3); if you wish to change it, each spatial
        dimension should be at least 200.
    '''
    assert input_shape[0] >= 200 and input_shape[1] >= 200, 'Shape cannot be less than 200'
    assert input_shape[2] == 3, 'RGB channels required'
    RecModel = InceptionV3(input_shape=input_shape, weights=None, include_top=False)
    RecModel.compile(optimizer='adam', loss=triplet_loss, metrics=['accuracy'])
    return RecModel

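# Usage sketch: init_model() builds an InceptionV3 base with random weights
# (weights=None) and include_top=False, so predicting on a (1, 200, 200, 3)
# batch yields a 4-D convolutional feature map that the rest of the script
# treats as the image encoding.
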
# Image encoding. The main aim of this function is to convert an image into an encoding.
def image_to_encoding(image_name, model):
    '''
    Runs one forward propagation on the image and returns its encoding.

    Arguments:
    image_name: name of the image to encode, including its extension.
    model: the model through which to forward propagate.
    '''
    # Read the image and convert BGR (OpenCV's default order) to RGB
    a = cv2.imread('./images/' + str(image_name))
    a = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)
    # Generate a 200x200 zero canvas to pad the image
    b = np.zeros((200, 200, 3), dtype='uint8')
    # Copy the image into the top-left corner of the canvas
    b[:a.shape[0], :a.shape[1], :] = a
    # Add a batch dimension
    b = b.reshape(1, 200, 200, 3)
    # Forward propagate through the model network
    embedding = model.predict_on_batch(b)
    return embedding

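# Example of the padding step (hypothetical size): a 150x120 RGB image lands in
# the top-left of the 200x200x3 zero canvas, so rows 150-199 and columns 120-199
# stay black; images larger than 200 pixels on a side are assumed not to occur.
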
def create_encodings(path_to_images='./images'):
    '''
    Generates an embedding for every image in the directory. Each element of the
    dataset takes roughly 0.15s, so use accordingly.

    Arguments:
    path_to_images: path to the directory that contains the images.
    '''
    imgs = os.listdir(path_to_images)
    database = {}
    # Key each encoding by the file name without its extension (the product id);
    # relies on the module-level RecModel created in the __main__ block
    for img in tqdm(sorted(imgs)):
        database[img.split('.')[0]] = image_to_encoding(img, RecModel)
    return database

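# Usage sketch (hypothetical file names): with ./images holding 12345.jpg and
# 67890.jpg, create_encodings() returns one entry per image, e.g.
# {'12345': embedding_1, '67890': embedding_2}, keyed by product id.
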
# Function to verify an image against a stored product identity
def verify(image, identity, database, model):
    """
    Verifies whether two product images are redundant (near-duplicates).

    Arguments:
    image -- file name of an image inside ./images
    identity -- string, the productId you'd like to verify the image against.
    database -- python dictionary mapping products to their encodings (vectors).
    model -- your InceptionV3 model instance in Keras.

    Returns:
    dist -- distance between the image's encoding and the encoding of "identity" in the database.
    same -- True if the images are considered the same, False otherwise.
    """
    # Step 1: Compute the encoding for the image using image_to_encoding()
    encoding = image_to_encoding(image, model)
    # Step 2: Compute the distance to the identity's stored encoding
    dist = np.linalg.norm(database[identity] - encoding)
    # Step 3: Treat the pair as the same product when the distance is below the threshold
    same = dist < 1
    return dist, same

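# Usage sketch (hypothetical names): once RecModel and database exist,
#   dist, same = verify('12345.jpg', '67890', database, RecModel)
# gives the L2 distance between the two encodings and whether it falls below
# the threshold of 1.
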
if __name__ == '__main__':
    if not os.path.exists('./dataset'):
        os.mkdir('./dataset')
    # Fetch the image archive and the product CSV if the images are not present yet
    if not os.path.exists('./images/'):
        try:
            link_zip = 'https://transfer.sh/ShBts/images_fkcdn.zip'
            link_tunics = 'https://transfer.sh/8gQj4/tunics.csv'
            if not os.path.exists('./dataset/images_fkcdn.zip'):
                urlretrieve(link_zip, './dataset/images_fkcdn.zip')
                urlretrieve(link_tunics, './dataset/tunics.csv')
            os.system('unzip ./dataset/images_fkcdn.zip')
        except Exception as e:
            print(e)

    RecModel = init_model()
    database = create_encodings()
    df = pd.read_csv('./dataset/tunics.csv')
    productFamily = json.load(open('pf.json'))
    imgs = sorted(os.listdir('./images'))
    # Compare each image with the other images in the array
    checks = {}
    ptr = 0

    # This loop compares each image only with the images that come after it,
    # not with the already visited ones, so no pair is checked twice.
    # Note: only the first two images are used as anchors here.
    for i in range(2):
        anchor_image = imgs[i]
        anchor_name = imgs[i].split('.')[0]
        checks[anchor_name] = []
        for j in range(i + 1, len(imgs)):
            target_image = imgs[j]
            target_name = imgs[j].split('.')[0]
            # Flag the pair only when the two products are not already in the same family
            if target_name not in productFamily[anchor_name] and verify(anchor_image, target_name, database, RecModel)[1]:
                checks[anchor_name].append(target_name)
            print(ptr)
            ptr += 1

    # Dump the checks as JSON (to be pushed to the temp server)
    with open('output.json', 'w') as f:
        json.dump(checks, f)
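
# Output sketch (hypothetical ids): output.json ends up looking like
# {"12345": ["67890", "24680"], "13579": []}, mapping each anchor product id to
# the target ids judged redundant with it.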
