@@ -1076,15 +1076,17 @@ def _iter(self):
         num_examples_to_skip = 0
         iterator = iter(self.ex_iterable)
 
+        # We use the same logic as in Dataset.map, but with less features/formatting
+        # since they're handled by FormattedExamplesIterable
+
         if self.formatting:
             formatter = get_formatter(self.formatting.format_type)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
+            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
         else:
             format_dict = None
 
         def iter_batched_inputs():
+            nonlocal current_idx
             for key, example in iterator:
                 # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
                 iterator_batch = (
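
A note on the `format_dict` change above: it no longer falls back to `cast_to_python_objects`, because only tensor formats need a per-batch conversion inside `_iter`; plain-python casting is now left to `FormattedExamplesIterable`. A minimal sketch of the new contract, assuming the public `datasets.formatting` helpers:

```python
# Sketch only: mirrors the format_dict selection in the hunk above.
from datasets.formatting import TensorFormatter, get_formatter

def pick_format_dict(format_type):
    formatter = get_formatter(format_type)
    # tensor formats (e.g. "torch", "numpy") tensorize batches; others skip formatting here
    return formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
```
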
@@ -1104,17 +1106,21 @@ def iter_batched_inputs():
                 ):  # ignore last batch
                     return
                 batch = _examples_to_batch(examples)
+                # we need to format here in case we need to stack tensors together
                 batch = format_dict(batch) if format_dict else batch
                 indices = [current_idx + i for i in range(len(key_examples_list))]
+                current_idx += len(indices)
                 yield indices, (key, batch)
 
         def iter_inputs():
+            nonlocal current_idx
             for key, example in iterator:
                 # If not batched, we can apply the transform and yield the example directly
                 # first copy the example, since we might drop some keys
                 example = dict(example)
-                example = format_dict(example) if format_dict else example
-                yield current_idx, (key, example)
+                # no need to do formatting here
+                current_idx += 1
+                yield current_idx - 1, (key, example)
 
         def validate_function_output(processed_inputs):
             if self.batched and processed_inputs:
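
The index bookkeeping now lives inside `iter_inputs`/`iter_batched_inputs` via `nonlocal current_idx`, so indices are assigned when inputs are consumed rather than when outputs are yielded, which matters once outputs can complete out of order in the async path. A standalone sketch of the pattern, with hypothetical names:

```python
# Standalone sketch (hypothetical names) of the nonlocal index bookkeeping.
from itertools import islice

def make_input_iterators(items, batch_size):
    current_idx = 0

    def iter_inputs():
        nonlocal current_idx
        for item in items:
            current_idx += 1
            yield current_idx - 1, item  # index assigned at consumption time

    def iter_batched_inputs():
        nonlocal current_idx
        it = iter(items)
        while batch := list(islice(it, batch_size)):
            indices = list(range(current_idx, current_idx + len(batch)))
            current_idx += len(batch)
            yield indices, batch

    return iter_inputs, iter_batched_inputs
```
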
@@ -1147,17 +1153,7 @@ def prepare_outputs(key_example, inputs, processed_inputs):
                     if processed_inputs is key_example[1] and c in processed_inputs:
                         del processed_inputs[c]
             transformed_inputs = {**inputs, **processed_inputs}
-            if self.features:
-                for c in self.features.keys():
-                    if c not in transformed_inputs:
-                        transformed_inputs[c] = (
-                            [None] * len(transformed_inputs[next(iter(processed_inputs))]) if self.batched else None
-                        )
-                transformed_inputs = (
-                    self.features.decode_batch(transformed_inputs)
-                    if self.batched
-                    else self.features.decode_example(transformed_inputs)
-                )
+            # no need to do features decoding here
             return transformed_inputs
 
         def apply_function(key_example, indices):
@@ -1185,6 +1181,11 @@ def iter_outputs():
             nonlocal tasks, loop
             inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
             if inspect.iscoroutinefunction(self.function):
+                if self._state_dict:
+                    previous_state = self.ex_iterable.state_dict()
+                    self._state_dict["previous_state"] = previous_state
+                    previous_state_task = None
+                    previous_state_example_idx = self._state_dict["previous_state_example_idx"]
                 indices: Union[list[int], list[list[int]]] = []
                 for i, key_example in inputs_iterator:
                     indices.append(i)
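
The `previous_state*` variables initialized here drive `state_dict()` checkpointing for async `map` functions: a snapshot of the underlying iterable only becomes the committed resume point once every example consumed before it has been yielded. A usage sketch with the real `IterableDataset` resumption API, assuming a streaming dataset `ds`:

```python
# Usage sketch: state_dict()/load_state_dict() are the documented resume APIs.
it = iter(ds)
for _ in range(3):
    next(it)
state = ds.state_dict()    # checkpoint mid-stream
# ... later, possibly in a new process:
ds.load_state_dict(state)  # iteration resumes after the 3rd example
```
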
@@ -1198,42 +1199,57 @@ def iter_outputs():
                         done, pending = loop.run_until_complete(
                             asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                         )
+                        if len(tasks) >= 10 * config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                            loop.run_until_complete(tasks[0])
                     # yield finished tasks
                     while tasks and tasks[0].done():
-                        yield indices.pop(0), tasks.pop(0).result()
+                        i, task = indices.pop(0), tasks.pop(0)
+                        yield i, task.result()
+                        if self._state_dict and task is previous_state_task:
+                            self._state_dict["previous_state"] = previous_state
+                            self._state_dict["num_examples_since_previous_state"] = 0
+                            self._state_dict["previous_state_example_idx"] = previous_state_example_idx
+                            previous_state, previous_state_task = None, None
+                        # checkpoint
+                        if self._state_dict and previous_state_task is None and tasks:
+                            previous_state = self.ex_iterable.state_dict()
+                            previous_state_task = tasks[-1]
+                            previous_state_example_idx = current_idx
                 while tasks:
                     yield indices[0], loop.run_until_complete(tasks[0])
                     indices.pop(0), tasks.pop(0)
             else:
-                for i, key_example in inputs_iterator:
-                    yield i, apply_function(key_example, i)
-
-        try:
-            if self.batched:
                 if self._state_dict:
-                    self._state_dict["previous_state"] = self.ex_iterable.state_dict()
-                    self._state_dict["num_examples_since_previous_state"] = 0
-                    self._state_dict["previous_state_example_idx"] = current_idx
-                for key, transformed_batch in iter_outputs():
-                    # yield one example at a time from the transformed batch
-                    for example in _batch_to_examples(transformed_batch):
-                        current_idx += 1
-                        if self._state_dict:
-                            self._state_dict["num_examples_since_previous_state"] += 1
-                        if num_examples_to_skip > 0:
-                            num_examples_to_skip -= 1
-                            continue
-                        yield key, example
-                if self._state_dict:
+                    if self.batched:
                         self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                         self._state_dict["num_examples_since_previous_state"] = 0
                         self._state_dict["previous_state_example_idx"] = current_idx
-            else:
-                for key, transformed_example in iter_outputs():
-                    current_idx += 1
+                for i, key_example in inputs_iterator:
                     if self._state_dict:
-                        self._state_dict["previous_state_example_idx"] += 1
-                    yield key, transformed_example
+                        if not self.batched:
+                            self._state_dict["previous_state_example_idx"] = current_idx
+                    yield i, apply_function(key_example, i)
+                    if self._state_dict:
+                        if self.batched:
+                            self._state_dict["previous_state"] = self.ex_iterable.state_dict()
+                            self._state_dict["num_examples_since_previous_state"] = 0
+                            self._state_dict["previous_state_example_idx"] = current_idx
+
+        try:
+            outputs = iter_outputs()
+            if self.batched:
+                outputs = (
+                    (key, transformed_example)
+                    for key, transformed_batch in outputs
+                    for transformed_example in _batch_to_examples(transformed_batch)
+                )
+            for key, transformed_example in outputs:
+                if self._state_dict and self._state_dict["previous_state"] is not None:
+                    self._state_dict["num_examples_since_previous_state"] += 1
+                if num_examples_to_skip > 0:
+                    num_examples_to_skip -= 1
+                    continue
+                yield key, transformed_example
         except (Exception, KeyboardInterrupt):
             if loop:
                 logger.debug(f"Canceling {len(tasks)} async tasks.")
@@ -1800,7 +1816,7 @@ def _init_state_dict(self) -> dict:
 
     def __iter__(self):
         if not self.formatting or self.formatting.is_table:
-            formatter = PythonFormatter()
+            formatter = PythonFormatter(features=self._features if not self.ex_iterable.is_typed else None)
         else:
             formatter = get_formatter(
                 self.formatting.format_type,
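
`PythonFormatter` can carry `features` and decode while it formats, which is why the wrapper only passes them when the upstream iterable is not already typed. A small illustration, assuming `pyarrow` and the public formatter API:

```python
# Illustration only: a formatter built with features decodes as it formats.
import pyarrow as pa
from datasets import Features, Value
from datasets.formatting import PythonFormatter

formatter = PythonFormatter(features=Features({"id": Value("int64")}))
table = pa.table({"id": [0, 1, 2]})
print(formatter.format_row(table))  # {'id': 0}
```
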
@@ -1817,15 +1833,17 @@ def __iter__(self):
         format_dict = (
             formatter.recursive_tensorize
             if isinstance(formatter, TensorFormatter)
-            else cast_to_python_objects  # cast in case features is None
+            else None  # cast in case features is None
         )
         for key, example in self.ex_iterable:
             # don't apply feature types if already applied by ex_iterable (e.g. in case of chained with_format)
             if self.features and not self.ex_iterable.is_typed:
                 example = _apply_feature_types_on_example(
                     example, self.features, token_per_repo_id=self.token_per_repo_id
                 )
-            yield key, format_dict(example)
+            if format_dict:
+                example = format_dict(example)
+            yield key, example
 
     def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
         if not self.features:
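
The loop above first applies feature types (unless the upstream iterable is already typed), then tensorizes only when a tensor formatter is in play. The same per-example pipeline, as a self-contained sketch with stand-in callables:

```python
# Self-contained sketch (stand-in callables) of the per-example pipeline above.
from typing import Any, Callable, Iterator, Optional

def formatted_examples(
    examples: Iterator[dict],
    decode: Optional[Callable[[dict], dict]] = None,    # e.g. feature-type decoding
    tensorize: Optional[Callable[[dict], Any]] = None,  # e.g. recursive_tensorize
) -> Iterator[Any]:
    for example in examples:
        if decode is not None:      # skipped when ex_iterable.is_typed
            example = decode(example)
        if tensorize is not None:   # skipped for the plain "python" format
            example = tensorize(example)
        yield example
```
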
@@ -2049,7 +2067,7 @@ def __setstate__(self, d):
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
 
     def _head(self, n=5):
-        return _examples_to_batch(list(self.take(n)))
+        return next(iter(self.iter(batch_size=n)))
 
     @property
     def epoch(self) -> int:
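
`_head` now reuses the public `iter()` API instead of reassembling examples by hand, so the preview batch goes through the same formatting path as normal iteration. In user code the equivalent call would be:

```python
# Equivalent public call, assuming a streaming dataset `ds`:
first_batch = next(iter(ds.iter(batch_size=5)))  # dict of 5-element columns
```
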
@@ -2111,15 +2129,8 @@ def _iter_pytorch(self):
         if self._starting_state_dict:
             ex_iterable.load_state_dict(self._starting_state_dict)
 
-        if self._formatting:
-            formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
-        else:
-            format_dict = None
-
         if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
+            formatter = get_formatter(self._formatting.format_type, features=self.features)
             if ex_iterable.iter_arrow:
                 iterator = ex_iterable.iter_arrow()
             else:
@@ -2129,13 +2140,8 @@ def _iter_pytorch(self):
                 return
         else:
             for key, example in ex_iterable:
-                if self.features and not ex_iterable.is_typed:
-                    # `IterableDataset` automatically fills missing columns with None.
-                    # This is done with `_apply_feature_types_on_example`.
-                    example = _apply_feature_types_on_example(
-                        example, self.features, token_per_repo_id=self._token_per_repo_id
-                    )
-                yield format_dict(example) if format_dict else example
+                # no need to format thanks to FormattedExamplesIterable
+                yield example
         logger.debug(
             f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards."
         )
@@ -2191,6 +2197,14 @@ def _prepare_ex_iterable_for_iteration(
             )
             ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
 
+        if self._formatting or (self.features and ex_iterable.features != self.features):
+            ex_iterable = FormattedExamplesIterable(
+                ex_iterable,
+                formatting=self._formatting,
+                features=self.features,
+                token_per_repo_id=self._token_per_repo_id,
+            )
+
         self._state_dict = ex_iterable._init_state_dict()
         if self._starting_state_dict:
             ex_iterable.load_state_dict(self._starting_state_dict)
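
This hunk is the pivot of the change: formatting and feature decoding are pushed into the `ex_iterable` pipeline itself, so the `__iter__` loops downstream can simply forward examples. The wrapping rule, restated as a sketch:

```python
# Sketch of the wrapping rule: wrap when an explicit format was requested, or
# when the declared features differ from what the underlying iterable yields.
def needs_formatted_wrapper(formatting, features, ex_iterable_features) -> bool:
    return bool(formatting) or (features is not None and ex_iterable_features != features)
```
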
@@ -2207,15 +2221,8 @@ def __iter__(self):
             return
 
         ex_iterable = self._prepare_ex_iterable_for_iteration()
-        if self._formatting:
-            formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
-        else:
-            format_dict = None
-
         if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
+            formatter = get_formatter(self._formatting.format_type, features=self.features)
             if ex_iterable.iter_arrow:
                 iterator = ex_iterable.iter_arrow()
             else:
@@ -2225,13 +2232,8 @@ def __iter__(self):
             return
 
         for key, example in ex_iterable:
-            if self.features and not ex_iterable.is_typed:
-                # `IterableDataset` automatically fills missing columns with None.
-                # This is done with `_apply_feature_types_on_example`.
-                example = _apply_feature_types_on_example(
-                    example, self.features, token_per_repo_id=self._token_per_repo_id
-                )
-            yield format_dict(example) if format_dict else example
+            # no need to format thanks to FormattedExamplesIterable
+            yield example
 
     def iter(self, batch_size: int, drop_last_batch: bool = False):
         """Iterate through the batches of size `batch_size`.
@@ -2244,9 +2246,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
 
         if self._formatting:
             formatter = get_formatter(self._formatting.format_type, features=self.features)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
-            )
+            format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
         else:
             format_dict = None
 
@@ -2267,10 +2267,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
             if drop_last_batch and len(examples) < batch_size:  # ignore last batch
                 return
             batch = _examples_to_batch(examples)
-            if self.features and not ex_iterable.is_typed:
-                # `IterableDataset` automatically fills missing columns with None.
-                # This is done with `_apply_feature_types_on_batch`.
-                batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
+            # we need to format here in case we need to stack tensors together
             yield format_dict(batch) if format_dict else batch
 
     @staticmethod
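
Batch-level tensorization has to happen after `_examples_to_batch`, because stacking per-example values into one batch tensor needs the whole column at once. For instance, with the "torch" format (assuming `torch` is installed), this is roughly what `recursive_tensorize` amounts to for one column:

```python
# Illustration: a column of equal-length per-example lists becomes one
# stacked tensor, which is only possible batch-side.
import torch

column = [[1, 2], [3, 4], [5, 6]]  # three examples' values for one column
print(torch.tensor(column).shape)  # torch.Size([3, 2])
```
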
@@ -3241,7 +3238,13 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableData
         def batch_fn(unbatched):
             return {k: [v] for k, v in unbatched.items()}
 
-        return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
+        if self.features:
+            features = Features({col: [feature] for col, feature in self.features.items()})
+        else:
+            features = None
+        return self.map(
+            batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
+        )
 
 
 def _concatenate_iterable_datasets(
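
`batch()` now declares output features by wrapping each column's feature in a list, the `Features` shorthand for a sequence, so the batched dataset keeps its type information instead of losing it through `map`. For example:

```python
# Example of the feature wrapping used above.
from datasets import Features, Value

features = Features({"id": Value("int64"), "text": Value("string")})
batched = Features({col: [feature] for col, feature in features.items()})
# batched["id"] is now a sequence of int64 values, one list per batch
```
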