@@ -62,14 +62,14 @@ class ImageDataset(Dataset):
62
62
def __init__ (
63
63
self ,
64
64
preprocess ,
65
+ tokenizer ,
65
66
folder ,
66
67
enable_text = True ,
67
68
enable_image = True ,
68
69
enable_metadata = False ,
69
70
input_sampler = lambda a : a ,
70
71
):
71
72
super ().__init__ ()
72
- import clip # pylint: disable=import-outside-toplevel
73
73
74
74
self .keys , text_files , image_files , metadata_files = folder_to_keys (
75
75
folder , enable_text , enable_image , enable_metadata
@@ -80,7 +80,7 @@ def __init__(
80
80
self .enable_metadata = enable_metadata
81
81
keys_set = set (self .keys )
82
82
if self .enable_text :
83
- self .tokenizer = lambda text : clip . tokenize ([text ], truncate = True )[0 ]
83
+ self .tokenizer = lambda text : tokenizer ([text ])[0 ]
84
84
self .text_files = {k : v for k , v in text_files .items () if k in keys_set }
85
85
if self .enable_image :
86
86
self .image_files = {k : v for k , v in image_files .items () if k in keys_set }
@@ -125,6 +125,7 @@ def __getitem__(self, ind):
125
125
def create_webdataset (
126
126
urls ,
127
127
image_transform ,
128
+ tokenizer ,
128
129
enable_text = True ,
129
130
enable_image = True ,
130
131
image_key = "jpg" ,
@@ -134,15 +135,14 @@ def create_webdataset(
134
135
input_sampler = lambda a : a ,
135
136
):
136
137
"""Create a WebDataset reader, it can read a webdataset of image, text and json"""
137
- import clip # pylint: disable=import-outside-toplevel
138
138
import webdataset as wds # pylint: disable=import-outside-toplevel
139
139
140
140
urls = input_sampler (urls )
141
141
142
142
dataset = wds .WebDataset (urls , cache_dir = cache_path , cache_size = 10 ** 10 , handler = wds .handlers .warn_and_continue )
143
143
144
- def tokenizer (text ):
145
- return clip . tokenize ([text ], truncate = True )[0 ]
144
+ def _tokenizer (text ):
145
+ return tokenizer ([text ])[0 ]
146
146
147
147
def filter_dataset (item ):
148
148
if enable_text and caption_key not in item :
@@ -167,7 +167,7 @@ def preprocess_dataset(item):
167
167
if enable_text :
168
168
text = item [caption_key ]
169
169
caption = text .decode ("utf-8" )
170
- tokenized_text = tokenizer (caption )
170
+ tokenized_text = _tokenizer (caption )
171
171
output ["text_tokens" ] = tokenized_text
172
172
output ["text" ] = caption
173
173
@@ -207,6 +207,7 @@ def __init__(
207
207
self ,
208
208
sampler ,
209
209
preprocess ,
210
+ tokenizer ,
210
211
input_dataset ,
211
212
batch_size ,
212
213
num_prepro_workers ,
@@ -215,7 +216,9 @@ def __init__(
215
216
enable_metadata = False ,
216
217
) -> None :
217
218
super ().__init__ ()
218
- dataset = get_image_dataset ()(preprocess , input_dataset , enable_text , enable_image , enable_metadata , sampler )
219
+ dataset = get_image_dataset ()(
220
+ preprocess , tokenizer , input_dataset , enable_text , enable_image , enable_metadata , sampler
221
+ )
219
222
self .dataloader = dataset_to_dataloader (dataset , batch_size , num_prepro_workers , "files" )
220
223
221
224
def __iter__ (self ):
@@ -230,6 +233,7 @@ def __init__(
230
233
self ,
231
234
sampler ,
232
235
preprocess ,
236
+ tokenizer ,
233
237
input_dataset ,
234
238
batch_size ,
235
239
num_prepro_workers ,
@@ -244,6 +248,7 @@ def __init__(
244
248
dataset = create_webdataset (
245
249
input_dataset ,
246
250
preprocess ,
251
+ tokenizer ,
247
252
enable_text = enable_text ,
248
253
enable_image = enable_image ,
249
254
image_key = wds_image_key ,
0 commit comments