
Commit 210c3f9: "package it"
1 parent: a46fb35

13 files changed: +336 -86 lines

Diff for: .github/python-publish.yml (+31)

@@ -0,0 +1,31 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [created]
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install setuptools wheel twine
+    - name: Build and publish
+      env:
+        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+      run: |
+        python setup.py sdist bdist_wheel
+        twine upload dist/*

Diff for: .gitignore (+11 -1)

@@ -1 +1,11 @@
-.env
+.env
+*.egg-info
+.vscode
+.env
+__pycache__
+myimglist.txt
+.ipynb_checkpoints
+output_folder
+indice_folder
+image_folder
+cat

Diff for: .gitpod.DockerFile (+3)

@@ -0,0 +1,3 @@
+FROM gitpod/workspace-full:latest
+
+RUN apt-get update && apt-get install -y python3-opencv

Diff for: .gitpod.yml (+2)

@@ -0,0 +1,2 @@
+image:
+  file: .gitpod.DockerFile

Diff for: HISTORY.md (+3)

@@ -0,0 +1,3 @@
+## 1.0.0
+
+* it works

Diff for: README.md (+32 -18)

@@ -1,4 +1,8 @@
 # clip-retrieval
+[![pypi](https://img.shields.io/pypi/v/clip-retrieval.svg)](https://pypi.python.org/pypi/clip-retrieval)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rom1504/clip-retrieval/blob/master/notebook/clip-retrieval-getting-started.ipynb)
+[![Try it on gitpod](https://img.shields.io/badge/try-on%20gitpod-brightgreen.svg)](https://gitpod.io/#https://github.com/rom1504/clip-retrieval)
+
 Easily compute clip embeddings and build a clip retrieval system with them.

 * clip batch allows you to quickly (1500 samples/s on a 3080) compute image and text embeddings and indices
@@ -9,18 +13,23 @@ Easily compute clip embeddings and build a clip retrieval system with them.
 End to end, this makes it possible to build a simple semantic search system.
 Interested in learning about semantic search in general? You can read my [medium post](https://rom1504.medium.com/semantic-search-with-embeddings-index-anything-8fb18556443c) on the topic.

+## Install
+
+pip install clip-retrieval
+
 ## clip batch

-First install it by running:
+Get some images in an `example_folder`, for example by doing:
 ```
-python3 -m venv .env
-source .env/bin/activate
-pip install -U pip
-pip install clip-anytorch faiss-cpu fire
+pip install img2dataset
+echo 'https://placekitten.com/200/305' >> myimglist.txt
+echo 'https://placekitten.com/200/304' >> myimglist.txt
+echo 'https://placekitten.com/200/303' >> myimglist.txt
+img2dataset --url_list=myimglist.txt --output_folder=image_folder --thread_count=64 --image_size=256
 ```
+You can also put text files with the same names as the images in that folder, to get the text embeddings.

-Then put some images in a `example_folder` and some text with the same name (or use --enable_text=False) then
-* `python clip_batch.py --dataset_path example_folder --output_folder output_folder`
+Then run `clip-retrieval batch --dataset_path image_folder --output_folder indice_folder`

 Output folder will contain:
 * description_list containing the list of captions line by line
@@ -33,24 +42,16 @@
 ## Clip filter

 Once the embeddings are computed, you may want to filter out the data by a specific query.
-For that you can run `python clip_filter.py --query "dog" --output_folder "dog/" --indice_name "example_index"`
+For that you can run `clip-retrieval filter --query "cat" --output_folder "cat/" --indice_folder "indice_folder"`
 It will copy the 100 best images for this query into the output folder.
 Using the `--num_results` or `--threshold` options may be helpful to refine the filter.

 ## Clip back

-First install it by running:
-```bash
-python3 -m venv .env
-source .env/bin/activate
-pip install -U pip
-pip install clip-anytorch faiss-cpu fire flask flask_cors flask_restful
-```
-
 Then run (output_folder is the output of clip batch)
 ```bash
 echo '{"example_index": "output_folder"}' > indices_paths.json
-python clip_back.py 1234
+clip-retrieval back --port 1234 --indices-paths indices_paths.json
 ```

 At this point you have a simple flask server running on port 1234 that can answer these queries:
@@ -78,4 +79,17 @@ and returns:
         "text": "some result text"
     }
 ]
-```
+```
+
+## For development
+
+Either locally, or in [gitpod](https://gitpod.io/#https://github.com/rom1504/clip-retrieval) (do `export PIP_USER=false` there)
+
+Set up a virtualenv:
+
+```
+python3 -m venv .env
+source .env/bin/activate
+pip install -U pip
+pip install -e .
+```
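
For reference, the `/knn-service` endpoint above can be queried with nothing but the standard library. This is a minimal sketch, assuming the server is running locally on port 1234 with an `example_index` loaded; only the `text` and `image` fields are visible in this commit's `clip_back.py`, so the `indice_name` field is an assumption:

```python
# Sketch: POST a text query to a running clip-back server and save the
# returned base64-encoded JPEGs. Not part of this commit.
import base64
import json
from urllib.request import Request, urlopen

payload = json.dumps({"text": "cat", "indice_name": "example_index"}).encode("utf-8")
req = Request(
    "http://localhost:1234/knn-service",
    data=payload,
    headers={"Content-Type": "application/json"},
)
results = json.loads(urlopen(req).read())
for i, result in enumerate(results):
    # Each result carries a base64-encoded JPEG and its caption.
    with open(f"result_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(result["image"]))
    print(result["text"])
```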

Diff for: clip_retrieval/__init__.py

Whitespace-only changes.

Diff for: clip_retrieval/cli.py (+16)

@@ -0,0 +1,16 @@
+from clip_retrieval.clip_back import clip_back
+from clip_retrieval.clip_batch import clip_batch
+from clip_retrieval.clip_filter import clip_filter
+import fire
+import logging
+
+
+def main():
+    """Main entry point"""
+    fire.Fire(
+        {
+            "back": clip_back,
+            "batch": clip_batch,
+            "filter": clip_filter
+        }
+    )
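
Because `fire.Fire` turns each dictionary key into a subcommand, `clip-retrieval back`, `clip-retrieval batch` and `clip-retrieval filter` dispatch straight to the three imported functions. They can equally be called from Python; a sketch reusing the folder names from the README example:

```python
# Sketch: calling the packaged entry point directly, equivalent to
# `clip-retrieval batch --dataset_path image_folder --output_folder indice_folder`.
from clip_retrieval.clip_batch import clip_batch

clip_batch(
    dataset_path="image_folder",    # images downloaded with img2dataset above
    output_folder="indice_folder",  # receives embeddings, lists and faiss indices
)
```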

Diff for: clip_back.py renamed to clip_retrieval/clip_back.py (+52 -43)

@@ -11,49 +11,23 @@
 from PIL import Image
 import base64
 import os
+import fire


-device = "cuda" if torch.cuda.is_available() else "cpu"
-model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
-
-indices = json.load(open("indices_paths.json"))
-
-indices_loaded = {}
-
-for name, indice_folder in indices.items():
-    image_present = os.path.exists(indice_folder+"/image_list")
-    text_present = os.path.exists(indice_folder+"/description_list")
-    if image_present:
-        with open(indice_folder+"/image_list") as f:
-            image_list = f.read().split("\n")
-        image_index = faiss.read_index(indice_folder+"/image.index")
-    else:
-        image_list = None
-        image_index = None
-    if text_present:
-        with open(indice_folder+"/description_list") as f:
-            description_list = f.read().split("\n")
-        text_index = faiss.read_index(indice_folder+"/text.index")
-    else:
-        description_list = None
-        text_index = None
-    indices_loaded[name]={
-        'image_list': image_list,
-        'description_list': description_list,
-        'image_index': image_index,
-        'text_index': text_index
-    }
-
 class Health(Resource):
     def get(self):
         return "ok"

 class IndicesList(Resource):
-    def get(self):
-        return list(indices.keys())
+    def get(self, **kwargs):
+        return list(kwargs['indices'].keys())

 class KnnService(Resource):
-    def post(self):
+    def post(self, **kwargs):
+        indices_loaded = kwargs['indices_loaded']
+        device = kwargs['device']
+        model = kwargs['model']
+        preprocess = kwargs['preprocess']
         json_data = request.get_json(force=True)
         text_input = json_data["text"] if "text" in json_data else None
         image_input = json_data["image"] if "image" in json_data else None
@@ -91,15 +65,50 @@ def post(self):
             img.save(buffered, format="JPEG")
             img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
             results.append({"image": img_str, "text": description})
-        return results
-
+        return results


-app = Flask(__name__)
-api = Api(app)
-api.add_resource(IndicesList, '/indices-list')
-api.add_resource(KnnService, '/knn-service')
-api.add_resource(Health, '/')

-if __name__ == '__main__':
+def clip_back(indices_paths="indices_paths.json", port=1234):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+    indices = json.load(open(indices_paths))
+
+    indices_loaded = {}
+
+    for name, indice_folder in indices.items():
+        image_present = os.path.exists(indice_folder+"/image_list")
+        text_present = os.path.exists(indice_folder+"/description_list")
+        if image_present:
+            with open(indice_folder+"/image_list") as f:
+                image_list = [x for x in f.read().split("\n") if x !=""]
+            image_index = faiss.read_index(indice_folder+"/image.index")
+        else:
+            image_list = None
+            image_index = None
+        if text_present:
+            with open(indice_folder+"/description_list") as f:
+                description_list = [x for x in f.read().split("\n") if x !=""]
+            text_index = faiss.read_index(indice_folder+"/text.index")
+        else:
+            description_list = None
+            text_index = None
+        indices_loaded[name]={
+            'image_list': image_list,
+            'description_list': description_list,
+            'image_index': image_index,
+            'text_index': text_index
+        }
+
+    app = Flask(__name__)
+    api = Api(app)
+    api.add_resource(IndicesList, '/indices-list', resource_class_kwargs={'indices': indices})
+    api.add_resource(KnnService, '/knn-service', resource_class_kwargs={'indices_loaded': indices_loaded, 'device': device, \
+        'model': model, 'preprocess': preprocess})
+    api.add_resource(Health, '/')
     CORS(app)
-    app.run(host="0.0.0.0", port=int(sys.argv[1]), debug=False)
+    app.run(host="0.0.0.0", port=port, debug=False)
+
+
+if __name__ == '__main__':
+    fire.Fire(clip_back)
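
A caveat on the handler signatures above: flask_restful passes `resource_class_kwargs` to the Resource constructor, while the keyword arguments of `get`/`post` are filled from URL route variables, so reading the injected objects inside the handlers may not work as written. A constructor-based sketch of the same injection (not part of this commit):

```python
# Sketch: receive resource_class_kwargs in __init__, where
# flask_restful actually delivers them.
from flask_restful import Resource

class KnnService(Resource):
    def __init__(self, **kwargs):
        super().__init__()
        # Injected via api.add_resource(..., resource_class_kwargs={...})
        self.indices_loaded = kwargs["indices_loaded"]
        self.device = kwargs["device"]
        self.model = kwargs["model"]
        self.preprocess = kwargs["preprocess"]

    def post(self):
        # ...same body as above, reading from self.* instead of kwargs...
        pass
```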

Diff for: clip_batch.py renamed to clip_retrieval/clip_batch.py (+22 -10)

@@ -34,18 +34,22 @@ def __init__(self,
         if self.enable_text:
             text_files = [*path.glob('**/*.txt')]
             text_files = {text_file.stem: text_file for text_file in text_files}
+            if len(text_files) == 0:
+                self.enable_text = False
         if self.enable_image:
             image_files = [
                 *path.glob('**/*.png'), *path.glob('**/*.jpg'),
                 *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
             ]
             image_files = {image_file.stem: image_file for image_file in image_files}
+            if len(image_files) == 0:
+                self.enable_image = False

-        if enable_text and enable_image:
+        if self.enable_text and self.enable_image:
             keys = (image_files.keys() & text_files.keys())
-        elif enable_text:
+        elif self.enable_text:
             keys = text_files.keys()
-        elif enable_image:
+        elif self.enable_image:
             keys = image_files.keys()

         self.keys = list(keys)
@@ -63,27 +67,35 @@ def __len__(self):
     def __getitem__(self, ind):
         key = self.keys[ind]

+        output = {}
+
         if self.enable_image:
             image_file = self.image_files[key]
             image_tensor = self.image_transform(PIL.Image.open(image_file))
+            output["image_filename"] = str(image_file)
+            output["image_tensor"] = image_tensor


         if self.enable_text:
             text_file = self.text_files[key]
             descriptions = text_file.read_text().split('\n')
             description = descriptions[self.description_index]
             tokenized_text = self.tokenizer(description)
+            output["text_tokens"] = tokenized_text
+            output["text"] = description

-        return {"image_tensor": image_tensor, "text_tokens": tokenized_text, "image_filename": str(image_file), "text": description}
+        return output


-def main(dataset_path, output_folder, batch_size=256, num_prepro_workers=8, description_index=0, enable_text=True, enable_image=True):
+def clip_batch(dataset_path, output_folder, batch_size=256, num_prepro_workers=8, description_index=0, enable_text=True, enable_image=True):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
     if not os.path.exists(output_folder):
         os.mkdir(output_folder)
-    data = DataLoader(ImageDataset(preprocess, dataset_path, description_index=description_index, enable_text=enable_text, enable_image=enable_image), \
-        batch_size=batch_size, shuffle=False, num_workers=num_prepro_workers, pin_memory=True, prefetch_factor=2)
+    dataset = ImageDataset(preprocess, dataset_path, description_index=description_index, enable_text=enable_text, enable_image=enable_image)
+    enable_text = dataset.enable_text
+    enable_image = dataset.enable_image
+    data = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_prepro_workers, pin_memory=True, prefetch_factor=2)
     if enable_image:
         image_embeddings = []
         image_names = []
@@ -94,12 +106,12 @@ def main(dataset_path, output_folder, batch_size=256, num_prepro_workers=8, desc
     for i, item in enumerate(tqdm(data)):
         with torch.no_grad():
             if enable_image:
-                image_features = model.encode_image(item["image_tensor"].cuda())
+                image_features = model.encode_image(item["image_tensor"].to(device))
                 image_features /= image_features.norm(dim=-1, keepdim=True)
                 image_embeddings.append(image_features.cpu().numpy())
                 image_names.extend(item["image_filename"])
             if enable_text:
-                text_features = model.encode_text(item["text_tokens"].cuda())
+                text_features = model.encode_text(item["text_tokens"].to(device))
                 text_features /= text_features.norm(dim=-1, keepdim=True)
                 text_embeddings.append(text_features.cpu().numpy())
                 descriptions.extend(item["text"])
@@ -125,4 +137,4 @@
         faiss.write_index(text_index, output_folder +"/text.index")

 if __name__ == '__main__':
-    fire.Fire(main)
+    fire.Fire(clip_batch)
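
To make the batch output concrete, here is a sketch of loading the `image.index` and `image_list` files written by `clip-retrieval batch` and searching them with a text embedding, roughly what `clip_filter` (not shown in this diff) does; folder and file names follow the diffs above:

```python
# Sketch: search an image index produced by clip-retrieval batch with a
# text query. Not part of this commit.
import clip
import faiss
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device, jit=False)

image_index = faiss.read_index("indice_folder/image.index")
with open("indice_folder/image_list") as f:
    image_list = [x for x in f.read().split("\n") if x != ""]

with torch.no_grad():
    text_features = model.encode_text(clip.tokenize(["cat"]).to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)

# faiss expects float32; on GPU the CLIP model returns float16.
query = text_features.cpu().float().numpy()
distances, neighbors = image_index.search(query, 5)
for distance, neighbor in zip(distances[0], neighbors[0]):
    print(distance, image_list[neighbor])
```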
