@@ -4,14 +4,44 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-
 import os
+import shutil
 import struct

 import numpy as np
 import torch


+def make_builder(out_file, impl):
+    if impl == 'mmap':
+        return MMapIndexedDatasetBuilder(out_file)
+    else:
+        return IndexedDatasetBuilder(out_file)
+
+
+def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
+    if impl == 'raw' and IndexedRawTextDataset.exists(path):
+        assert dictionary is not None
+        return IndexedRawTextDataset(path, dictionary)
+    elif impl == 'lazy' and IndexedDataset.exists(path):
+        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
+    elif impl == 'cached' and IndexedDataset.exists(path):
+        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
+    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
+        return MMapIndexedDataset(path)
+
+    return None
+
+
+def dataset_exists(path, impl):
+    if impl == 'raw':
+        return IndexedRawTextDataset.exists(path)
+    elif impl == 'mmap':
+        return MMapIndexedDataset.exists(path)
+    else:
+        return IndexedDataset.exists(path)
+
+
 def read_longs(f, n):
     a = np.empty(n, dtype=np.int64)
     f.readinto(a)
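Editor's note on the new factory helpers above: a call site picks one `impl` string ('raw', 'lazy', 'cached', or 'mmap') and uses it for both writing and reading, and `make_dataset` returns `None` rather than raising when the files for the chosen `impl` are missing. A minimal sketch of a call site, assuming this module lives at `fairseq.data.indexed_dataset` and using a made-up file prefix:

```python
# Hypothetical call site; the import path and the 'data-bin/train' prefix
# are illustrative assumptions, not taken from this diff.
from fairseq.data import indexed_dataset

prefix, impl = 'data-bin/train', 'lazy'
if indexed_dataset.dataset_exists(prefix, impl):
    ds = indexed_dataset.make_dataset(prefix, impl)  # an IndexedDataset here
else:
    ds = None  # make_dataset(prefix, impl) would likewise return None
```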
@@ -37,6 +67,7 @@ def code(dtype):
     for k in dtypes.keys():
         if dtypes[k] == dtype:
             return k
+    raise ValueError(dtype)


 def index_file_path(prefix_path):
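The new `raise` makes `code(dtype)` fail loudly for a dtype missing from the module-level `dtypes` table (defined earlier in this file, outside the diff) instead of silently returning `None`. A quick self-contained illustration with a stand-in table:

```python
# Stand-in for the module's dtypes table; the real code assignments live
# earlier in indexed_dataset.py and are not shown in this diff.
import numpy as np

dtypes = {1: np.uint8, 5: np.int64}

def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)

assert code(np.int64) == 5
try:
    code(np.float16)  # not registered
except ValueError as err:
    print('unsupported dtype:', err)  # before this change: returned None
```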
@@ -100,8 +131,8 @@ def __len__(self):
     @staticmethod
     def exists(path):
         return (
-                os.path.exists(index_file_path(path)) and
-                os.path.exists(data_file_path(path))
+            os.path.exists(index_file_path(path)) and
+            os.path.exists(data_file_path(path))
         )

     @property
@@ -135,7 +166,7 @@ def prefetch(self, indices):
         for i in indices:
             self.cache_index[i] = ptx
             size = self.data_offsets[i + 1] - self.data_offsets[i]
-            a = self.cache[ptx:ptx + size]
+            a = self.cache[ptx: ptx + size]
             self.data_file.seek(self.data_offsets[i] * self.element_size)
             self.data_file.readinto(a)
             ptx += size
@@ -149,7 +180,7 @@ def __getitem__(self, i):
         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
         a = np.empty(tensor_size, dtype=self.dtype)
         ptx = self.cache_index[i]
-        np.copyto(a, self.cache[ptx:ptx + a.size])
+        np.copyto(a, self.cache[ptx: ptx + a.size])
         item = torch.from_numpy(a).long()
         if self.fix_lua_indexing:
             item -= 1  # subtract 1 for 0-based indexing
@@ -262,3 +293,169 @@ def finalize(self, index_file):
         write_longs(index, self.data_offsets)
         write_longs(index, self.sizes)
         index.close()
+
+
+def _warmup_mmap_file(path):
+    with open(path, 'rb') as stream:
+        while stream.read(100 * 1024 * 1024):
+            pass
+
+
+class MMapIndexedDataset(torch.utils.data.Dataset):
+    class Index(object):
+        _HDR_MAGIC = b'MMIDIDX\x00\x00'
+
+        @classmethod
+        def writer(cls, path, dtype):
+            class _Writer(object):
+                def __enter__(self):
+                    self._file = open(path, 'wb')
+
+                    self._file.write(cls._HDR_MAGIC)
+                    self._file.write(struct.pack('<Q', 1))
+                    self._file.write(struct.pack('<B', code(dtype)))
+
+                    return self
+
+                @staticmethod
+                def _get_pointers(sizes):
+                    dtype_size = dtype().itemsize
+                    address = 0
+                    pointers = []
+
+                    for size in sizes:
+                        pointers.append(address)
+                        address += size * dtype_size
+
+                    return pointers
+
+                def write(self, sizes):
+                    pointers = self._get_pointers(sizes)
+
+                    self._file.write(struct.pack('<Q', len(sizes)))
+
+                    sizes = np.array(sizes, dtype=np.int32)
+                    self._file.write(sizes.tobytes(order='C'))
+                    del sizes
+
+                    pointers = np.array(pointers, dtype=np.int64)
+                    self._file.write(pointers.tobytes(order='C'))
+                    del pointers
+
+                def __exit__(self, exc_type, exc_val, exc_tb):
+                    self._file.close()
+
+            return _Writer()
+
+        def __init__(self, path):
+            with open(path, 'rb') as stream:
+                magic_test = stream.read(9)
+                assert self._HDR_MAGIC == magic_test
+                version = struct.unpack('<Q', stream.read(8))
+                assert (1,) == version
+
+                dtype_code, = struct.unpack('<B', stream.read(1))
+                self._dtype = dtypes[dtype_code]
+                self._dtype_size = self._dtype().itemsize
+
+                self._len = struct.unpack('<Q', stream.read(8))[0]
+                offset = stream.tell()
+
+            _warmup_mmap_file(path)
+
+            self._bin_buffer = memoryview(np.memmap(path, mode='r', order='C'))
+            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
+            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
+                                           offset=offset + self._sizes.nbytes)
+
+        @property
+        def dtype(self):
+            return self._dtype
+
+        @property
+        def sizes(self):
+            return self._sizes
+
+        def __getitem__(self, i):
+            return self._pointers[i], self._sizes[i]
+
+        def __len__(self):
+            return self._len
+
+    def __init__(self, path):
+        super().__init__()
+
+        self._path = None
+        self._index = None
+        self._bin_buffer = None
+
+        self._do_init(path)
+
+    def __getstate__(self):
+        return self._path
+
+    def __setstate__(self, state):
+        self._do_init(state)
+
+    def _do_init(self, path):
+        self._path = path
+        self._index = self.Index(index_file_path(self._path))
+
+        _warmup_mmap_file(data_file_path(self._path))
+        self._bin_buffer = memoryview(np.memmap(data_file_path(self._path), mode='r', order='C'))
+
+    def __len__(self):
+        return len(self._index)
+
+    def __getitem__(self, i):
+        ptr, size = self._index[i]
+        tensor = torch.from_numpy(np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr))
+        if tensor.dtype == torch.int64:
+            return tensor
+        else:
+            return tensor.long()
+
+    @property
+    def sizes(self):
+        return self._index.sizes
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    @staticmethod
+    def exists(path):
+        return (
+            os.path.exists(index_file_path(path)) and
+            os.path.exists(data_file_path(path))
+        )
+
+
+class MMapIndexedDatasetBuilder(object):
+    def __init__(self, out_file, dtype=np.int64):
+        self._data_file = open(out_file, 'wb')
+        self._dtype = dtype
+        self._sizes = []
+
+    def add_item(self, tensor):
+        np_array = np.array(tensor.numpy(), dtype=self._dtype)
+        self._data_file.write(np_array.tobytes(order='C'))
+        self._sizes.append(np_array.size)
+
+    def merge_file_(self, another_file):
+        # Concatenate index
+        index = MMapIndexedDataset.Index(index_file_path(another_file))
+        assert index.dtype == self._dtype
+
+        for size in index.sizes:
+            self._sizes.append(size)
+
+        # Concatenate data
+        with open(data_file_path(another_file), 'rb') as f:
+            shutil.copyfileobj(f, self._data_file)
+
+    def finalize(self, index_file):
+        self._data_file.close()
+
+        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
+            index.write(self._sizes)
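To make the new classes concrete, here is a round-trip sketch (an editor's illustration, not part of the commit). It assumes this module is importable as `indexed_dataset` and that `index_file_path`/`data_file_path` append `.idx`/`.bin` to a prefix, as defined earlier in this file; the file names and token IDs are made up:

```python
import struct
import torch
import indexed_dataset  # assumed import name for this module

# Build: token tensors stream straight into train.bin; per-item sizes are
# buffered in memory and written to train.idx by finalize().
builder = indexed_dataset.make_builder('train.bin', impl='mmap')
builder.add_item(torch.IntTensor([4, 8, 15]))
builder.add_item(torch.IntTensor([16, 23, 42, 7]))
builder.finalize('train.idx')

# Read: train.bin is memory-mapped; items come back as int64 tensors.
ds = indexed_dataset.make_dataset('train', impl='mmap')
assert len(ds) == 2
print(ds[1])  # tensor([16, 23, 42,  7])

# The index header, exactly as Index.writer lays it out:
with open('train.idx', 'rb') as f:
    assert f.read(9) == b'MMIDIDX\x00\x00'        # magic
    version, = struct.unpack('<Q', f.read(8))     # currently always 1
    dtype_code, = struct.unpack('<B', f.read(1))  # code(np.int64) by default
    count, = struct.unpack('<Q', f.read(8))       # number of items, here 2
# ...followed by `count` int32 sizes and `count` int64 byte offsets.
```

Note also that `__getstate__` pickles only the path and `__setstate__` re-runs `_do_init`, so an unpickled `MMapIndexedDataset` simply re-opens its mmap; that is what makes it cheap to hand to DataLoader worker processes.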