Commit a1c997b

Davide Caroselli authored and facebook-github-bot committed
Memory-Mapped IndexedDataset implementation (#589)
Summary: Following the discussion in #574:

- Implemented MMapIndexedDataset and MMapIndexedDatasetBuilder, compatible with IndexedDataset/IndexedDatasetBuilder
- Updated scripts/read_binarized.py to support the new MMapIndexedDataset
- Replaced the '--raw-text' and '--lazy-load' options with '--dataset-impl', and moved the option definition from custom task args to the more appropriate high-level options.add_dataset_args()
- Also implemented utility functions in indexed_dataset: make_dataset(), dataset_exists()

Pull Request resolved: #589
Differential Revision: D14597128
Pulled By: myleott
fbshipit-source-id: 4e92d99920cbaa52cfe5a0f1f5d9ae5c92d4268e
1 parent e4edf27 commit a1c997b

11 files changed (+289 -123 lines)
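For orientation, here is a minimal round-trip sketch of the new helpers (make_builder, make_dataset, dataset_exists); the file names and token ids are invented for illustration and are not part of the commit:

    import torch
    from fairseq.data import indexed_dataset

    # Binarize two token sequences into example.bin / example.idx with the
    # memory-mapped builder added by this commit.
    builder = indexed_dataset.make_builder('example.bin', impl='mmap')
    builder.add_item(torch.IntTensor([4, 5, 6, 2]))
    builder.add_item(torch.IntTensor([7, 8, 2]))
    builder.finalize('example.idx')

    # Read them back; 'example' is the prefix shared by the .bin/.idx pair.
    assert indexed_dataset.dataset_exists('example', impl='mmap')
    ds = indexed_dataset.make_dataset('example', impl='mmap')
    print(len(ds), ds[0])  # 2 tensor([4, 5, 6, 2])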

fairseq/data/__init__.py  (+2 -1)

@@ -9,7 +9,7 @@
 from .fairseq_dataset import FairseqDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .concat_dataset import ConcatDataset
-from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
 from .language_pair_dataset import LanguagePairDataset
 from .lm_context_window_dataset import LMContextWindowDataset
 from .monolingual_dataset import MonolingualDataset
@@ -39,6 +39,7 @@
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'LMContextWindowDataset',
+    'MMapIndexedDataset',
     'MonolingualDataset',
     'NoisingDataset',
     'RoundRobinZipDatasets',

fairseq/data/indexed_dataset.py  (+202 -5)

@@ -4,14 +4,44 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-
 import os
+import shutil
 import struct
 
 import numpy as np
 import torch
 
 
+def make_builder(out_file, impl):
+    if impl == 'mmap':
+        return MMapIndexedDatasetBuilder(out_file)
+    else:
+        return IndexedDatasetBuilder(out_file)
+
+
+def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
+    if impl == 'raw' and IndexedRawTextDataset.exists(path):
+        assert dictionary is not None
+        return IndexedRawTextDataset(path, dictionary)
+    elif impl == 'lazy' and IndexedDataset.exists(path):
+        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
+    elif impl == 'cached' and IndexedDataset.exists(path):
+        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
+    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
+        return MMapIndexedDataset(path)
+
+    return None
+
+
+def dataset_exists(path, impl):
+    if impl == 'raw':
+        return IndexedRawTextDataset.exists(path)
+    elif impl == 'mmap':
+        return MMapIndexedDataset.exists(path)
+    else:
+        return IndexedDataset.exists(path)
+
+
 def read_longs(f, n):
     a = np.empty(n, dtype=np.int64)
     f.readinto(a)
@@ -37,6 +67,7 @@ def code(dtype):
     for k in dtypes.keys():
         if dtypes[k] == dtype:
             return k
+    raise ValueError(dtype)
 
 
 def index_file_path(prefix_path):
@@ -100,8 +131,8 @@ def __len__(self):
     @staticmethod
     def exists(path):
         return (
-            os.path.exists(index_file_path(path)) and
-            os.path.exists(data_file_path(path))
+            os.path.exists(index_file_path(path)) and
+            os.path.exists(data_file_path(path))
         )
 
     @property
@@ -135,7 +166,7 @@ def prefetch(self, indices):
         for i in indices:
             self.cache_index[i] = ptx
             size = self.data_offsets[i + 1] - self.data_offsets[i]
-            a = self.cache[ptx : ptx + size]
+            a = self.cache[ptx: ptx + size]
             self.data_file.seek(self.data_offsets[i] * self.element_size)
             self.data_file.readinto(a)
             ptx += size
@@ -149,7 +180,7 @@ def __getitem__(self, i):
         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
         a = np.empty(tensor_size, dtype=self.dtype)
         ptx = self.cache_index[i]
-        np.copyto(a, self.cache[ptx : ptx + a.size])
+        np.copyto(a, self.cache[ptx: ptx + a.size])
         item = torch.from_numpy(a).long()
         if self.fix_lua_indexing:
             item -= 1  # subtract 1 for 0-based indexing
@@ -262,3 +293,169 @@ def finalize(self, index_file):
         write_longs(index, self.data_offsets)
         write_longs(index, self.sizes)
         index.close()
+
+
+def _warmup_mmap_file(path):
+    with open(path, 'rb') as stream:
+        while stream.read(100 * 1024 * 1024):
+            pass
+
+
+class MMapIndexedDataset(torch.utils.data.Dataset):
+    class Index(object):
+        _HDR_MAGIC = b'MMIDIDX\x00\x00'
+
+        @classmethod
+        def writer(cls, path, dtype):
+            class _Writer(object):
+                def __enter__(self):
+                    self._file = open(path, 'wb')
+
+                    self._file.write(cls._HDR_MAGIC)
+                    self._file.write(struct.pack('<Q', 1))
+                    self._file.write(struct.pack('<B', code(dtype)))
+
+                    return self
+
+                @staticmethod
+                def _get_pointers(sizes):
+                    dtype_size = dtype().itemsize
+                    address = 0
+                    pointers = []
+
+                    for size in sizes:
+                        pointers.append(address)
+                        address += size * dtype_size
+
+                    return pointers
+
+                def write(self, sizes):
+                    pointers = self._get_pointers(sizes)
+
+                    self._file.write(struct.pack('<Q', len(sizes)))
+
+                    sizes = np.array(sizes, dtype=np.int32)
+                    self._file.write(sizes.tobytes(order='C'))
+                    del sizes
+
+                    pointers = np.array(pointers, dtype=np.int64)
+                    self._file.write(pointers.tobytes(order='C'))
+                    del pointers
+
+                def __exit__(self, exc_type, exc_val, exc_tb):
+                    self._file.close()
+
+            return _Writer()
+
+        def __init__(self, path):
+            with open(path, 'rb') as stream:
+                magic_test = stream.read(9)
+                assert self._HDR_MAGIC == magic_test
+                version = struct.unpack('<Q', stream.read(8))
+                assert (1,) == version
+
+                dtype_code, = struct.unpack('<B', stream.read(1))
+                self._dtype = dtypes[dtype_code]
+                self._dtype_size = self._dtype().itemsize
+
+                self._len = struct.unpack('<Q', stream.read(8))[0]
+                offset = stream.tell()
+
+            _warmup_mmap_file(path)
+
+            self._bin_buffer = memoryview(np.memmap(path, mode='r', order='C'))
+            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
+            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
+                                           offset=offset + self._sizes.nbytes)
+
+        @property
+        def dtype(self):
+            return self._dtype
+
+        @property
+        def sizes(self):
+            return self._sizes
+
+        def __getitem__(self, i):
+            return self._pointers[i], self._sizes[i]
+
+        def __len__(self):
+            return self._len
+
+    def __init__(self, path):
+        super().__init__()
+
+        self._path = None
+        self._index = None
+        self._bin_buffer = None
+
+        self._do_init(path)
+
+    def __getstate__(self):
+        return self._path
+
+    def __setstate__(self, state):
+        self._do_init(state)
+
+    def _do_init(self, path):
+        self._path = path
+        self._index = self.Index(index_file_path(self._path))
+
+        _warmup_mmap_file(data_file_path(self._path))
+        self._bin_buffer = memoryview(np.memmap(data_file_path(self._path), mode='r', order='C'))
+
+    def __len__(self):
+        return len(self._index)
+
+    def __getitem__(self, i):
+        ptr, size = self._index[i]
+        tensor = torch.from_numpy(np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr))
+        if tensor.dtype == torch.int64:
+            return tensor
+        else:
+            return tensor.long()
+
+    @property
+    def sizes(self):
+        return self._index.sizes
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    @staticmethod
+    def exists(path):
+        return (
+            os.path.exists(index_file_path(path)) and
+            os.path.exists(data_file_path(path))
+        )
+
+
+class MMapIndexedDatasetBuilder(object):
+    def __init__(self, out_file, dtype=np.int64):
+        self._data_file = open(out_file, 'wb')
+        self._dtype = dtype
+        self._sizes = []
+
+    def add_item(self, tensor):
+        np_array = np.array(tensor.numpy(), dtype=self._dtype)
+        self._data_file.write(np_array.tobytes(order='C'))
+        self._sizes.append(np_array.size)
+
+    def merge_file_(self, another_file):
+        # Concatenate index
+        index = MMapIndexedDataset.Index(index_file_path(another_file))
+        assert index.dtype == self._dtype
+
+        for size in index.sizes:
+            self._sizes.append(size)
+
+        # Concatenate data
+        with open(data_file_path(another_file), 'rb') as f:
+            shutil.copyfileobj(f, self._data_file)
+
+    def finalize(self, index_file):
+        self._data_file.close()
+
+        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
+            index.write(self._sizes)
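To make the .idx layout written by MMapIndexedDataset.Index.writer above concrete, here is a small standalone sketch that re-reads such a file with the same struct/numpy calls the class relies on; 'example.idx' is a hypothetical file produced by the builder, not part of the commit:

    import struct
    import numpy as np

    with open('example.idx', 'rb') as f:
        assert f.read(9) == b'MMIDIDX\x00\x00'              # magic
        version, = struct.unpack('<Q', f.read(8))            # format version, currently 1
        dtype_code, = struct.unpack('<B', f.read(1))         # key into the dtypes map
        count, = struct.unpack('<Q', f.read(8))              # number of stored items
        sizes = np.frombuffer(f.read(count * 4), dtype=np.int32)     # length of each item
        pointers = np.frombuffer(f.read(count * 8), dtype=np.int64)  # byte offset of each item in the .bin file

    print(version, dtype_code, count, sizes, pointers)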

fairseq/options.py  (+4 -3)

@@ -198,9 +198,8 @@ def add_preprocess_args(parser):
                        help="number of source words to retain")
     group.add_argument("--alignfile", metavar="ALIGN", default=None,
                        help="an alignment file (optional)")
-    group.add_argument("--output-format", metavar="FORMAT", default="binary",
-                       choices=["binary", "raw"],
-                       help="output format (optional)")
+    parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
+                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
     group.add_argument("--joined-dictionary", action="store_true",
                        help="Generate joined dictionary")
     group.add_argument("--only-source", action="store_true",
@@ -226,6 +225,8 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='maximum number of sentences in a batch')
     group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
                        help='batch size will be a multiplier of this value')
+    parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
+                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            choices=['train', 'valid', 'test'],
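For reference, a self-contained argparse sketch (not the fairseq parser itself) of how the new flag surfaces as args.dataset_impl, which the tasks below pass to indexed_dataset.make_dataset():

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset-impl', metavar='FORMAT',
                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached',
                        help='output dataset implementation')

    args = parser.parse_args(['--dataset-impl', 'mmap'])
    print(args.dataset_impl)  # 'mmap'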

fairseq/tasks/cross_lingual_lm.py  (+6 -11)

@@ -17,9 +17,7 @@
 
 from fairseq.data import (
     ConcatDataset,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
+    indexed_dataset,
     TokenBlockDataset,
 )
 
@@ -118,14 +116,11 @@ def _load_single_lang_dataset(self, split, epoch):
             split_k = split + (str(k) if k > 0 else '')
             path = os.path.join(data_path, split_k)
 
-            if self.args.raw_text and IndexedRawTextDataset.exists(path):
-                ds = IndexedRawTextDataset(path, self.dictionary)
-            elif not self.args.raw_text and IndexedDataset.exists(path):
-                if self.args.lazy_load:
-                    ds = IndexedDataset(path, fix_lua_indexing=True)
-                else:
-                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
-            else:
+            ds = indexed_dataset.make_dataset(
+                path, impl=self.args.dataset_impl, fix_lua_indexing=True,
+                dictionary=self.dictionary,
+            )
+            if ds is None:
                 if k > 0:
                     break
                 else:

fairseq/tasks/language_modeling.py  (+13 -13)

@@ -8,21 +8,19 @@
 import itertools
 import os
 
-import torch
 import numpy as np
+import torch
 
+from fairseq import utils
 from fairseq.data import (
     ConcatDataset,
     Dictionary,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
     MonolingualDataset,
     TokenBlockDataset,
     TransformEosDataset,
     TruncatedDictionary,
+    indexed_dataset
 )
-
 from . import FairseqTask, register_task
 
 
@@ -101,6 +99,13 @@ def setup_task(cls, args, **kwargs):
         Args:
            args (argparse.Namespace): parsed command-line arguments
        """
+        if getattr(args, 'raw_text', False):
+            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
+            args.dataset_impl = 'raw'
+        elif getattr(args, 'lazy_load', False):
+            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
+            args.dataset_impl = 'lazy'
+
         dictionary = None
         output_dictionary = None
         if args.data:
@@ -154,15 +159,10 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
             path = os.path.join(data_path, split_k)
+            ds = indexed_dataset.make_dataset(path, impl=self.args.dataset_impl,
+                                              fix_lua_indexing=True, dictionary=self.dictionary)
 
-            if self.args.raw_text and IndexedRawTextDataset.exists(path):
-                ds = IndexedRawTextDataset(path, self.dictionary)
-            elif not self.args.raw_text and IndexedDataset.exists(path):
-                if self.args.lazy_load:
-                    ds = IndexedDataset(path, fix_lua_indexing=True)
-                else:
-                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
-            else:
+            if ds is None:
                 if k > 0:
                     break
                 else:
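The setup_task shim above keeps old command lines working; a standalone sketch of that mapping, with a hand-built namespace standing in for parsed args (values invented for illustration):

    import argparse
    from fairseq import utils

    args = argparse.Namespace(raw_text=True, lazy_load=False, dataset_impl='cached')

    # Same backward-compatibility logic as LanguageModelingTask.setup_task above.
    if getattr(args, 'raw_text', False):
        utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
        args.dataset_impl = 'raw'
    elif getattr(args, 'lazy_load', False):
        utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
        args.dataset_impl = 'lazy'

    print(args.dataset_impl)  # 'raw'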
