
Commit 2c56c07

Support get_tokenizer in clip back inf (rom1504#236)
* Support get_tokenizer in clip back inf
* remove truncate
* fix
* test fixes
1 parent 7a4959d commit 2c56c07

4 files changed (+20 −10 lines)
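This commit threads the tokenizer returned by `load_clip` through the inference readers, instead of having each reader import `clip` and call `clip.tokenize([text], truncate=True)` itself. Since the readers only ever invoke `tokenizer([text])[0]`, any callable that maps a list of strings to a batch of token tensors fits the contract; open_clip's `get_tokenizer` returns exactly that. A minimal sketch of the assumed contract (the model name is illustrative):

```python
# Minimal sketch of the tokenizer contract the readers rely on after this
# commit: a callable from a list of strings to a batch of token tensors.
# open_clip's get_tokenizer is one such callable; "ViT-B-32" is only an
# illustrative model name.
import open_clip

tokenizer = open_clip.get_tokenizer("ViT-B-32")
tokens = tokenizer(["a photo of a cat"])[0]  # 1D tensor of token ids
print(tokens.shape)  # e.g. torch.Size([77]) for CLIP's default context length
```

The explicit `truncate=True` flag disappears ("remove truncate" in the commit message) because open_clip's tokenizers truncate to the model's context length by default.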

Diff for: clip_retrieval/clip_inference/reader.py

+12 −7

@@ -62,14 +62,14 @@ class ImageDataset(Dataset):
     def __init__(
         self,
         preprocess,
+        tokenizer,
         folder,
         enable_text=True,
         enable_image=True,
         enable_metadata=False,
         input_sampler=lambda a: a,
     ):
         super().__init__()
-        import clip  # pylint: disable=import-outside-toplevel
 
         self.keys, text_files, image_files, metadata_files = folder_to_keys(
             folder, enable_text, enable_image, enable_metadata
@@ -80,7 +80,7 @@ def __init__(
         self.enable_metadata = enable_metadata
         keys_set = set(self.keys)
         if self.enable_text:
-            self.tokenizer = lambda text: clip.tokenize([text], truncate=True)[0]
+            self.tokenizer = lambda text: tokenizer([text])[0]
             self.text_files = {k: v for k, v in text_files.items() if k in keys_set}
         if self.enable_image:
             self.image_files = {k: v for k, v in image_files.items() if k in keys_set}
@@ -125,6 +125,7 @@ def __getitem__(self, ind):
 def create_webdataset(
     urls,
     image_transform,
+    tokenizer,
     enable_text=True,
     enable_image=True,
     image_key="jpg",
@@ -134,15 +135,14 @@ def create_webdataset(
     input_sampler=lambda a: a,
 ):
     """Create a WebDataset reader, it can read a webdataset of image, text and json"""
-    import clip  # pylint: disable=import-outside-toplevel
     import webdataset as wds  # pylint: disable=import-outside-toplevel
 
     urls = input_sampler(urls)
 
     dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10, handler=wds.handlers.warn_and_continue)
 
-    def tokenizer(text):
-        return clip.tokenize([text], truncate=True)[0]
+    def _tokenizer(text):
+        return tokenizer([text])[0]
 
     def filter_dataset(item):
         if enable_text and caption_key not in item:
@@ -167,7 +167,7 @@ def preprocess_dataset(item):
         if enable_text:
             text = item[caption_key]
             caption = text.decode("utf-8")
-            tokenized_text = tokenizer(caption)
+            tokenized_text = _tokenizer(caption)
             output["text_tokens"] = tokenized_text
             output["text"] = caption
 
@@ -207,6 +207,7 @@ def __init__(
         self,
         sampler,
         preprocess,
+        tokenizer,
         input_dataset,
         batch_size,
         num_prepro_workers,
@@ -215,7 +216,9 @@ def __init__(
         enable_metadata=False,
     ) -> None:
         super().__init__()
-        dataset = get_image_dataset()(preprocess, input_dataset, enable_text, enable_image, enable_metadata, sampler)
+        dataset = get_image_dataset()(
+            preprocess, tokenizer, input_dataset, enable_text, enable_image, enable_metadata, sampler
+        )
         self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "files")
 
     def __iter__(self):
@@ -230,6 +233,7 @@ def __init__(
         self,
         sampler,
         preprocess,
+        tokenizer,
         input_dataset,
         batch_size,
         num_prepro_workers,
@@ -244,6 +248,7 @@ def __init__(
         dataset = create_webdataset(
             input_dataset,
             preprocess,
+            tokenizer,
             enable_text=enable_text,
             enable_image=enable_image,
             image_key=wds_image_key,
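With the new signature, `ImageDataset` receives its tokenizer from the caller rather than building one from a hard-coded `clip` import. A hypothetical construction, reusing the same `load_clip` helper that worker.py unpacks below (the import path, folder path, and output keys are assumptions for illustration):

```python
# Hypothetical usage of the updated ImageDataset; import path, folder
# layout, and output keys are assumptions for illustration.
from clip_retrieval.load_clip import load_clip
from clip_retrieval.clip_inference.reader import ImageDataset

_, preprocess, tokenizer = load_clip(warmup_batch_size=1)

dataset = ImageDataset(
    preprocess,
    tokenizer,          # injected: replaces the old internal clip.tokenize lambda
    "/path/to/images",  # folder of images with matching .txt caption files
    enable_text=True,
    enable_image=True,
    enable_metadata=False,
)
sample = dataset[0]  # dict with "image_tensor" and "text_tokens" entries (assumed keys)
```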

Diff for: clip_retrieval/clip_inference/worker.py

+3 −1

@@ -49,7 +49,7 @@ def worker(
     print(f"dataset is {len(input_dataset)}", flush=True)
 
     def reader_builder(sampler):
-        _, preprocess, _ = load_clip(
+        _, preprocess, tokenizer = load_clip(
             clip_model=clip_model,
             use_jit=use_jit,
             warmup_batch_size=batch_size,
@@ -59,6 +59,7 @@ def reader_builder(sampler):
         return FilesReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
@@ -70,6 +71,7 @@ def reader_builder(sampler):
         return WebdatasetReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
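Every call site now follows one convention: unpack the tokenizer from `load_clip` and pass it straight through to the reader. A sketch for the webdataset path, mirroring `reader_builder` above (import path, shard list, sampler, and sizes are illustrative assumptions):

```python
# Sketch of the new calling convention for WebdatasetReader; the import
# path, shard list, and parameter values are illustrative assumptions.
from clip_retrieval.load_clip import load_clip
from clip_retrieval.clip_inference.reader import WebdatasetReader

_, preprocess, tokenizer = load_clip(clip_model="ViT-B/32", use_jit=True, warmup_batch_size=256)

reader = WebdatasetReader(
    lambda shards: shards,        # identity sampler, for illustration
    preprocess,
    tokenizer,
    ["dataset/shard-00000.tar"],  # placeholder shard list
    256,                          # batch_size
    8,                            # num_prepro_workers
)

for batch in reader:
    pass  # each batch carries image tensors and the matching token ids
```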

Diff for: tests/test_clip_inference/test_reader.py

+3 −1

@@ -17,7 +17,7 @@ def test_reader(file_format):
     input_dataset = [tar_folder + "/image1.tar", tar_folder + "/image2.tar"]
     batch_size = 2
     num_prepro_workers = 2
-    _, preprocess, _ = load_clip(warmup_batch_size=batch_size)
+    _, preprocess, tokenizer = load_clip(warmup_batch_size=batch_size)
 
     output_partition_count = 2
     actual_values = []
@@ -27,6 +27,7 @@ def test_reader(file_format):
         reader = FilesReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
@@ -38,6 +39,7 @@ def test_reader(file_format):
         reader = WebdatasetReader(
             sampler,
             preprocess,
+            tokenizer,
             input_dataset,
             batch_size,
             num_prepro_workers,
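Since the tokenizer is now an injected dependency, the tests pass it in exactly as any caller would, and nothing prevents swapping in a different compatible tokenizer. A hypothetical variation reusing the test's local variables (`sampler`, `preprocess`, etc.):

```python
# Hypothetical, not part of this commit: inject a different model's
# tokenizer into the same reader. Reuses the test's local variables.
import open_clip

other_tokenizer = open_clip.get_tokenizer("roberta-ViT-B-32")  # example model name
reader = FilesReader(
    sampler,
    preprocess,
    other_tokenizer,
    input_dataset,
    batch_size,
    num_prepro_workers,
)
```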

Diff for: tests/test_clip_inference/test_runner.py

+2 −1

@@ -21,10 +21,11 @@ def test_runner():
     with tempfile.TemporaryDirectory() as tmpdir:
 
         def reader_builder(sampler):
-            _, preprocess, _ = load_clip(warmup_batch_size=batch_size)
+            _, preprocess, tokenizer = load_clip(warmup_batch_size=batch_size)
             return FilesReader(
                 sampler,
                 preprocess,
+                tokenizer,
                 folder,
                 batch_size,
                 num_prepro_workers,
