1
- #!pip install clip-by-openai torch faiss
1
+ #!pip install clip-by-openai faiss-cpu fire
2
2
import torch
3
3
import clip
4
4
from PIL import Image
@@ -29,7 +29,6 @@ def __init__(self,
29
29
):
30
30
super ().__init__ ()
31
31
path = Path (folder )
32
- self .model = model
33
32
self .enable_text = enable_text
34
33
self .enable_image = enable_image
35
34
@@ -76,13 +75,14 @@ def __getitem__(self, ind):
76
75
description = descriptions [self .description_index ]
77
76
tokenized_text = self .tokenizer ([description [:255 ]])[0 ]
78
77
79
- return {"image_tensor" : image_tensor , "text_tokens" : tokenized_text , "image_path " : str (image_file ), "text" : description }
78
+ return {"image_tensor" : image_tensor , "text_tokens" : tokenized_text , "image_filename " : str (image_file ), "text" : description }
80
79
81
80
82
- def main (dataset_path , output_folder , batch_size = 256 , num_prepro_workers = 32 , description_index = 0 , enable_text = True , enable_image = True ):
81
+ def main (dataset_path , output_folder , batch_size = 256 , num_prepro_workers = 8 , description_index = 0 , enable_text = True , enable_image = True ):
83
82
device = "cuda" if torch .cuda .is_available () else "cpu"
84
- model , preprocess = clip .load ("ViT-B/32" , device = device )
85
- os .mkdir (output_folder )
83
+ model , preprocess = clip .load ("ViT-B/32" , device = device , jit = False )
84
+ if not os .path .exists (output_folder ):
85
+ os .mkdir (output_folder )
86
86
data = DataLoader (ImageDataset (preprocess , dataset_path , description_index = description_index , enable_text = enable_text , enable_image = enable_image ), \
87
87
batch_size = batch_size , shuffle = False , num_workers = num_prepro_workers , pin_memory = True , prefetch_factor = 2 )
88
88
if enable_image :
@@ -101,7 +101,7 @@ def main(dataset_path, output_folder, batch_size=256, num_prepro_workers=32, des
101
101
if enable_text :
102
102
text_features = model .encode_text (item ["text_tokens" ].cuda ())
103
103
text_embeddings .append (text_features .cpu ().numpy ())
104
- descriptions .extend (item ["description " ])
104
+ descriptions .extend (item ["text " ])
105
105
106
106
if enable_image :
107
107
img_emb_mat = np .concatenate (image_embeddings )
0 commit comments