Skip to content

Commit dd46e46

Browse files
committed
add tests and fix bugs
1 parent ab4e4a3 commit dd46e46

File tree

6 files changed

+111
-29
lines changed

6 files changed

+111
-29
lines changed

Diff for: .gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ image_folder
1111
cat
1212
embedding_folder
1313
index_folder
14-
indices_paths.json
14+
indices_paths.json
15+
.coverage

Diff for: clip_retrieval/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""clip retrieval"""
2+
3+
from .clip_back import clip_back
4+
from .clip_filter import clip_filter
5+
from .clip_index import clip_index
6+
from .clip_inference import clip_inference

Diff for: clip_retrieval/clip_index.py

+26-26
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@
44
import fire
55
import os
66
from distutils.dir_util import copy_tree
7-
from multiprocessing import Process
7+
import logging
8+
9+
10+
LOGGER = logging.getLogger(__name__)
811

912

1013
def quantize(emb_folder, index_folder, index_name, max_index_memory_usage, current_memory_available, nb_cores):
1114
"""calls autofaiss to build an index"""
1215
try:
16+
LOGGER.debug(f"starting index {index_name}")
1317
if os.path.exists(emb_folder):
18+
LOGGER.debug(
19+
f"embedding path exist, building index {index_name}"
20+
f"using embeddings {emb_folder} ; saving in {index_folder}"
21+
)
1422
build_index(
1523
embeddings_path=emb_folder,
1624
index_path=index_folder + "/" + index_name + ".index",
@@ -19,8 +27,10 @@ def quantize(emb_folder, index_folder, index_name, max_index_memory_usage, curre
1927
current_memory_available=current_memory_available,
2028
nb_cores=nb_cores,
2129
)
30+
LOGGER.debug(f"index {index_name} done")
2231
except Exception as e: # pylint: disable=broad-except
23-
print(e)
32+
LOGGER.exception(f"index {index_name} failed")
33+
raise e
2434

2535

2636
def clip_index(
@@ -34,32 +44,22 @@ def clip_index(
3444
nb_cores=None,
3545
):
3646
"""indexes clip embeddings using autofaiss"""
37-
p = Process(
38-
target=quantize,
39-
args=(
40-
embeddings_folder + "/" + image_subfolder,
41-
index_folder,
42-
"image",
43-
max_index_memory_usage,
44-
current_memory_available,
45-
nb_cores,
46-
),
47+
quantize(
48+
embeddings_folder + "/" + image_subfolder,
49+
index_folder,
50+
"image",
51+
max_index_memory_usage,
52+
current_memory_available,
53+
nb_cores,
4754
)
48-
p.start()
49-
p.join()
50-
p = Process(
51-
target=quantize,
52-
args=(
53-
embeddings_folder + "/" + text_subfolder,
54-
index_folder,
55-
"text",
56-
max_index_memory_usage,
57-
current_memory_available,
58-
nb_cores,
59-
),
55+
quantize(
56+
embeddings_folder + "/" + text_subfolder,
57+
index_folder,
58+
"text",
59+
max_index_memory_usage,
60+
current_memory_available,
61+
nb_cores,
6062
)
61-
p.start()
62-
p.join()
6363
if copy_metadata:
6464
copy_tree(embeddings_folder + "/metadata", index_folder + "/metadata")
6565

Diff for: clip_retrieval/clip_inference.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def __write_batch(self):
230230
img_emb_mat = np.concatenate(self.image_embeddings)
231231
output_path_img = self.img_emb_folder + "/img_emb_" + str(self.batch_num)
232232

233-
with self.fs.open(output_path_img, "wb") as f:
233+
with self.fs.open(output_path_img + ".npy", "wb") as f:
234234
npb = BytesIO()
235235
np.save(npb, img_emb_mat)
236236
f.write(npb.getbuffer())
@@ -242,7 +242,7 @@ def __write_batch(self):
242242
text_emb_mat = np.concatenate(self.text_embeddings)
243243
output_path_text = self.text_emb_folder + "/text_emb_" + str(self.batch_num)
244244

245-
with self.fs.open(output_path_text, "wb") as f:
245+
with self.fs.open(output_path_text + ".npy", "wb") as f:
246246
npb = BytesIO()
247247
np.save(npb, text_emb_mat)
248248
f.write(npb.getbuffer())

Diff for: pytest.ini

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
[pytest]
3+
log_cli = 1
4+
log_cli_level = DEBUG

Diff for: tests/test_end2end.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from img2dataset import download
2+
from clip_retrieval import clip_inference
3+
from clip_retrieval import clip_index
4+
import os
5+
import pandas as pd
6+
import shutil
7+
8+
test_list = [
9+
["first", "https://placekitten.com/400/600"],
10+
["second", "https://placekitten.com/200/300"],
11+
["third", "https://placekitten.com/300/200"],
12+
["fourth", "https://placekitten.com/400/400"],
13+
["fifth", "https://placekitten.com/200/200"],
14+
[None, "https://placekitten.com/200/200"],
15+
]
16+
17+
18+
def generate_parquet(output_file):
19+
df = pd.DataFrame(test_list, columns=["caption", "url"])
20+
df.to_parquet(output_file)
21+
22+
23+
def test_end2end():
24+
current_folder = os.path.dirname(__file__)
25+
test_folder = current_folder + "/" + "test_folder"
26+
if os.path.exists(test_folder):
27+
shutil.rmtree(test_folder)
28+
if not os.path.exists(test_folder):
29+
os.mkdir(test_folder)
30+
url_list_name = os.path.join(test_folder, "url_list")
31+
image_folder_name = os.path.join(test_folder, "images")
32+
33+
url_list_name += ".parquet"
34+
generate_parquet(url_list_name)
35+
36+
download(
37+
url_list_name,
38+
image_size=256,
39+
output_folder=image_folder_name,
40+
thread_count=32,
41+
input_format="parquet",
42+
output_format="webdataset",
43+
url_col="url",
44+
caption_col="caption",
45+
)
46+
47+
assert os.path.exists(image_folder_name)
48+
49+
embeddings_folder = os.path.join(test_folder, "embeddings")
50+
51+
clip_inference(
52+
input_dataset=f"{image_folder_name}/00000.tar",
53+
output_folder=embeddings_folder,
54+
input_format="webdataset",
55+
enable_metadata=True,
56+
write_batch_size=100000,
57+
batch_size=512,
58+
cache_path=None,
59+
)
60+
61+
assert os.path.exists(embeddings_folder)
62+
63+
index_folder = os.path.join(test_folder, "index")
64+
65+
os.mkdir(index_folder)
66+
67+
clip_index(embeddings_folder, index_folder=index_folder)
68+
69+
assert os.path.exists(index_folder + "/image.index")
70+
assert os.path.exists(index_folder + "/text.index")
71+

0 commit comments

Comments
 (0)