Commit f09db01

Fix small bugs with async map (#7445)
* fix async map resuming
* fix with_indices
* fix tests
* fix tests
* again
1 parent 67ffdfb commit f09db01
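For context, here is a minimal sketch of the workflow this commit fixes: resuming a streaming map that uses an async function from a mid-iteration checkpoint. The toy dataset and function are illustrative, not part of the commit.

import asyncio

from datasets import Dataset

async def add_one(example):
    await asyncio.sleep(0)  # stand-in for real async work, e.g. an HTTP call
    return {"id_plus_one": example["id"] + 1}

ds = Dataset.from_dict({"id": list(range(6))}).to_iterable_dataset()
ds = ds.map(add_one)

for idx, example in enumerate(ds):
    if idx == 2:
        state_dict = ds.state_dict()  # checkpoint mid-iteration
        break

ds.load_state_dict(state_dict)  # resuming restarts from the checkpoint
print(next(iter(ds)))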

File tree

2 files changed: +108 -85 lines changed


src/datasets/iterable_dataset.py

+86 -83
@@ -1076,15 +1076,17 @@ def _iter(self):
         num_examples_to_skip = 0
         iterator = iter(self.ex_iterable)

+        # We use the same logic as in Dataset.map, but with less features/formatting
+        # since they're handled by FormattedExamplesIterable
         if self.formatting:
             formatter = get_formatter(self.formatting.format_type)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
+            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
         else:
             format_dict = None

         def iter_batched_inputs():
+            nonlocal current_idx
             for key, example in iterator:
                 # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
                 iterator_batch = (
@@ -1104,17 +1106,21 @@ def iter_batched_inputs():
                 ):  # ignore last batch
                     return
                 batch = _examples_to_batch(examples)
+                # we need to format here in case we need to stack tensors together
                 batch = format_dict(batch) if format_dict else batch
                 indices = [current_idx + i for i in range(len(key_examples_list))]
+                current_idx += len(indices)
                 yield indices, (key, batch)

         def iter_inputs():
+            nonlocal current_idx
             for key, example in iterator:
                 # If not batched, we can apply the transform and yield the example directly
                 # first copy the example, since we might drop some keys
                 example = dict(example)
-                example = format_dict(example) if format_dict else example
-                yield current_idx, (key, example)
+                # no need to do formatting here
+                current_idx += 1
+                yield current_idx - 1, (key, example)

         def validate_function_output(processed_inputs):
             if self.batched and processed_inputs:
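The two hunks above move the current_idx bookkeeping into the input iterators, which is what keeps with_indices aligned for async functions. A hedged usage sketch (toy data and expected values are assumptions, not from the diff):

from datasets import Dataset

async def add_idx(example, idx):
    return {"id_plus_idx": example["id"] + idx}

ds = Dataset.from_dict({"id": [10, 11, 12]}).to_iterable_dataset()
ds = ds.map(add_idx, with_indices=True)
print(list(ds))  # expected id_plus_idx values: 10, 12, 14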
@@ -1147,17 +1153,7 @@ def prepare_outputs(key_example, inputs, processed_inputs):
                 if processed_inputs is key_example[1] and c in processed_inputs:
                     del processed_inputs[c]
             transformed_inputs = {**inputs, **processed_inputs}
-            if self.features:
-                for c in self.features.keys():
-                    if c not in transformed_inputs:
-                        transformed_inputs[c] = (
-                            [None] * len(transformed_inputs[next(iter(processed_inputs))]) if self.batched else None
-                        )
-                transformed_inputs = (
-                    self.features.decode_batch(transformed_inputs)
-                    if self.batched
-                    else self.features.decode_example(transformed_inputs)
-                )
+            # no need to do features decoding here
             return transformed_inputs

         def apply_function(key_example, indices):
@@ -1185,6 +1181,11 @@ def iter_outputs():
             nonlocal tasks, loop
             inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
             if inspect.iscoroutinefunction(self.function):
+                if self._state_dict:
+                    previous_state = self.ex_iterable.state_dict()
+                    self._state_dict["previous_state"] = previous_state
+                    previous_state_task = None
+                    previous_state_example_idx = self._state_dict["previous_state_example_idx"]
                 indices: Union[list[int], list[list[int]]] = []
                 for i, key_example in inputs_iterator:
                     indices.append(i)
@@ -1198,42 +1199,57 @@ def iter_outputs():
                         done, pending = loop.run_until_complete(
                             asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                         )
+                        if len(tasks) >= 10 * config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                            loop.run_until_complete(tasks[0])
                     # yield finished tasks
                     while tasks and tasks[0].done():
-                        yield indices.pop(0), tasks.pop(0).result()
+                        i, task = indices.pop(0), tasks.pop(0)
+                        yield i, task.result()
+                        if self._state_dict and task is previous_state_task:
+                            self._state_dict["previous_state"] = previous_state
+                            self._state_dict["num_examples_since_previous_state"] = 0
+                            self._state_dict["previous_state_example_idx"] = previous_state_example_idx
+                            previous_state, previous_state_task = None, None
+                    # checkpoint
+                    if self._state_dict and previous_state_task is None and tasks:
+                        previous_state = self.ex_iterable.state_dict()
+                        previous_state_task = tasks[-1]
+                        previous_state_example_idx = current_idx
                 while tasks:
                     yield indices[0], loop.run_until_complete(tasks[0])
                     indices.pop(0), tasks.pop(0)
             else:
-                for i, key_example in inputs_iterator:
-                    yield i, apply_function(key_example, i)
-
-        try:
-            if self.batched:
                 if self._state_dict:
-                    self._state_dict["previous_state"] = self.ex_iterable.state_dict()
-                    self._state_dict["num_examples_since_previous_state"] = 0
-                    self._state_dict["previous_state_example_idx"] = current_idx
-                for key, transformed_batch in iter_outputs():
-                    # yield one example at a time from the transformed batch
-                    for example in _batch_to_examples(transformed_batch):
-                        current_idx += 1
-                        if self._state_dict:
-                            self._state_dict["num_examples_since_previous_state"] += 1
-                        if num_examples_to_skip > 0:
-                            num_examples_to_skip -= 1
-                            continue
-                        yield key, example
-                    if self._state_dict:
+                    if self.batched:
                         self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                         self._state_dict["num_examples_since_previous_state"] = 0
                         self._state_dict["previous_state_example_idx"] = current_idx
-            else:
-                for key, transformed_example in iter_outputs():
-                    current_idx += 1
+                for i, key_example in inputs_iterator:
                     if self._state_dict:
-                        self._state_dict["previous_state_example_idx"] += 1
-                    yield key, transformed_example
+                        if not self.batched:
+                            self._state_dict["previous_state_example_idx"] = current_idx
+                    yield i, apply_function(key_example, i)
+                    if self._state_dict:
+                        if self.batched:
+                            self._state_dict["previous_state"] = self.ex_iterable.state_dict()
+                            self._state_dict["num_examples_since_previous_state"] = 0
+                            self._state_dict["previous_state_example_idx"] = current_idx
+
+        try:
+            outputs = iter_outputs()
+            if self.batched:
+                outputs = (
+                    (key, transformed_example)
+                    for key, transformed_batch in outputs
+                    for transformed_example in _batch_to_examples(transformed_batch)
+                )
+            for key, transformed_example in outputs:
+                if self._state_dict and self._state_dict["previous_state"] is not None:
+                    self._state_dict["num_examples_since_previous_state"] += 1
+                if num_examples_to_skip > 0:
+                    num_examples_to_skip -= 1
+                    continue
+                yield key, transformed_example
         except (Exception, KeyboardInterrupt):
             if loop:
                 logger.debug(f"Canceling {len(tasks)} async tasks.")
@@ -1800,7 +1816,7 @@ def _init_state_dict(self) -> dict:

     def __iter__(self):
         if not self.formatting or self.formatting.is_table:
-            formatter = PythonFormatter()
+            formatter = PythonFormatter(features=self._features if not self.ex_iterable.is_typed else None)
         else:
             formatter = get_formatter(
                 self.formatting.format_type,
@@ -1817,15 +1833,17 @@ def __iter__(self):
         format_dict = (
             formatter.recursive_tensorize
             if isinstance(formatter, TensorFormatter)
-            else cast_to_python_objects  # cast in case features is None
+            else None  # cast in case features is None
         )
         for key, example in self.ex_iterable:
             # don't apply feature types if already applied by ex_iterable (e.g. in case of chained with_format)
             if self.features and not self.ex_iterable.is_typed:
                 example = _apply_feature_types_on_example(
                     example, self.features, token_per_repo_id=self.token_per_repo_id
                 )
-            yield key, format_dict(example)
+            if format_dict:
+                example = format_dict(example)
+            yield key, example

     def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
         if not self.features:
@@ -2049,7 +2067,7 @@ def __setstate__(self, d):
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)

     def _head(self, n=5):
-        return _examples_to_batch(list(self.take(n)))
+        return next(iter(self.iter(batch_size=n)))

     @property
     def epoch(self) -> int:
@@ -2111,15 +2129,8 @@ def _iter_pytorch(self):
         if self._starting_state_dict:
             ex_iterable.load_state_dict(self._starting_state_dict)

-        if self._formatting:
-            formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
-        else:
-            format_dict = None
-
         if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
+            formatter = get_formatter(self._formatting.format_type, features=self.features)
             if ex_iterable.iter_arrow:
                 iterator = ex_iterable.iter_arrow()
             else:
@@ -2129,13 +2140,8 @@ def _iter_pytorch(self):
             return
         else:
             for key, example in ex_iterable:
-                if self.features and not ex_iterable.is_typed:
-                    # `IterableDataset` automatically fills missing columns with None.
-                    # This is done with `_apply_feature_types_on_example`.
-                    example = _apply_feature_types_on_example(
-                        example, self.features, token_per_repo_id=self._token_per_repo_id
-                    )
-                yield format_dict(example) if format_dict else example
+                # no need to format thanks to FormattedExamplesIterable
+                yield example
         logger.debug(
             f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards."
         )
@@ -2191,6 +2197,14 @@ def _prepare_ex_iterable_for_iteration(
             )
             ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)

+        if self._formatting or (self.features and ex_iterable.features != self.features):
+            ex_iterable = FormattedExamplesIterable(
+                ex_iterable,
+                formatting=self._formatting,
+                features=self.features,
+                token_per_repo_id=self._token_per_repo_id,
+            )
+
         self._state_dict = ex_iterable._init_state_dict()
         if self._starting_state_dict:
             ex_iterable.load_state_dict(self._starting_state_dict)
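With FormattedExamplesIterable inserted into the pipeline here, formatting and feature application happen inside the ex_iterable itself, so the __iter__ methods below can simply re-yield examples. An illustrative check (assumes torch is installed; not part of the diff):

from datasets import Dataset

ds = Dataset.from_dict({"id": [0, 1, 2]}).to_iterable_dataset().with_format("torch")
example = next(iter(ds))
print(type(example["id"]))  # expected: <class 'torch.Tensor'>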
@@ -2207,15 +2221,8 @@ def __iter__(self):
             return

         ex_iterable = self._prepare_ex_iterable_for_iteration()
-        if self._formatting:
-            formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
-        else:
-            format_dict = None
-
         if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
+            formatter = get_formatter(self._formatting.format_type, features=self.features)
             if ex_iterable.iter_arrow:
                 iterator = ex_iterable.iter_arrow()
             else:
@@ -2225,13 +2232,8 @@ def __iter__(self):
             return

         for key, example in ex_iterable:
-            if self.features and not ex_iterable.is_typed:
-                # `IterableDataset` automatically fills missing columns with None.
-                # This is done with `_apply_feature_types_on_example`.
-                example = _apply_feature_types_on_example(
-                    example, self.features, token_per_repo_id=self._token_per_repo_id
-                )
-            yield format_dict(example) if format_dict else example
+            # no need to format thanks to FormattedExamplesIterable
+            yield example

     def iter(self, batch_size: int, drop_last_batch: bool = False):
         """Iterate through the batches of size `batch_size`.
@@ -2244,9 +2246,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):

         if self._formatting:
             formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
+            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
         else:
             format_dict = None

@@ -2267,10 +2267,7 @@
             if drop_last_batch and len(examples) < batch_size:  # ignore last batch
                 return
             batch = _examples_to_batch(examples)
-            if self.features and not ex_iterable.is_typed:
-                # `IterableDataset` automatically fills missing columns with None.
-                # This is done with `_apply_feature_types_on_batch`.
-                batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
+            # we need to format here in case we need to stack tensors together
             yield format_dict(batch) if format_dict else batch

     @staticmethod
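For reference, a usage sketch of iter(), whose feature-casting branch is removed above; toy data, with the expected batches shown as comments:

from datasets import Dataset

ds = Dataset.from_dict({"id": [0, 1, 2, 3, 4]}).to_iterable_dataset()
for batch in ds.iter(batch_size=2, drop_last_batch=True):
    print(batch)  # {'id': [0, 1]}, then {'id': [2, 3]}; the short last batch is dropped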
@@ -3241,7 +3238,13 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableData
         def batch_fn(unbatched):
             return {k: [v] for k, v in unbatched.items()}

-        return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
+        if self.features:
+            features = Features({col: [feature] for col, feature in self.features.items()})
+        else:
+            features = None
+        return self.map(
+            batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
+        )


 def _concatenate_iterable_datasets(
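A hedged sketch of what the batch() change does: when features are known, each column feature becomes a list feature on the batched dataset. Toy data; the printed values are assumptions about the reprs:

from datasets import Dataset

ds = Dataset.from_dict({"id": list(range(5))}).to_iterable_dataset()
batched = ds._resolve_features().batch(batch_size=3)
print(batched.features)  # expected: {'id': [Value(dtype='int64', id=None)]}
print(list(batched))     # expected: [{'id': [0, 1, 2]}, {'id': [3, 4]}]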

tests/test_iterable_dataset.py

+22 -2
@@ -505,6 +505,13 @@ def test_mapped_examples_iterable_drop_last_batch(n, func, batched, batch_size):
         next(iter(ex_iterable))


+def _wrap_async(func, *args, **kwargs):
+    async def wrapped_func(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapped_func
+
+
 @pytest.mark.parametrize(
     "n, func, batched, batch_size",
     [
@@ -519,10 +526,11 @@ def test_mapped_examples_iterable_drop_last_batch(n, func, batched, batch_size):
         (5, lambda x, indices: {"id+idx": [i + j for i, j in zip(x["id"], indices)]}, True, -1),  # same with bs<=0
     ],
 )
-def test_mapped_examples_iterable_with_indices(n, func, batched, batch_size):
+@pytest.mark.parametrize("wrapper", [lambda x: x, _wrap_async])
+def test_mapped_examples_iterable_with_indices(n, func, batched, batch_size, wrapper):
     base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
     ex_iterable = MappedExamplesIterable(
-        base_ex_iterable, func, batched=batched, batch_size=batch_size, with_indices=True
+        base_ex_iterable, wrapper(func), batched=batched, batch_size=batch_size, with_indices=True
     )
     all_examples = [x for _, x in generate_examples_fn(n=n)]
     if batched is False:
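In isolation, the _wrap_async helper above turns a plain function into a coroutine function, so the new "wrapper" parameter makes each parametrized case also exercise the async code path of MappedExamplesIterable. A quick illustrative check:

import asyncio

def double(example):
    return {"id": example["id"] * 2}

async_double = _wrap_async(double)
assert asyncio.iscoroutinefunction(async_double)
assert asyncio.run(async_double({"id": 3})) == {"id": 6}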
@@ -2454,3 +2462,15 @@ def test_iterable_dataset_batch():
     assert len(batches[2]["text"]) == 2
     assert batches[2]["id"] == [8, 9]
     assert batches[2]["text"] == ["Text 8", "Text 9"]
+
+    # Test with features
+    batched_ds = ds._resolve_features().batch(batch_size=3)
+    batches = list(batched_ds)
+
+    assert batched_ds.features is not None
+    assert len(batches) == 4  # 3 full batches and 1 partial batch
+    for i, batch in enumerate(batches[:1]):
+        assert len(batch["id"]) == 3
+        assert len(batch["text"]) == 3
+        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
+        assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]
