
Commit 67ffdfb

Refactor string_to_dict to return None if there is no match instead of raising ValueError (#7435)
* Refactor string_to_dict to return None if there is no match instead of raising ValueError

  Instead of the try-except pattern previously used to handle a missing match, callers can now check whether the return value is None; when a match is known to exist, they can assert that the return value is not None.

* Allow source_url_fields to be None

  The source URLs can also be local file paths, see https://github.com/huggingface/datasets/actions/runs/13683185040/job/38380924390?pr=7435#step:10:9731
1 parent f693f4e commit 67ffdfb
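
The commit flips the calling convention around string_to_dict. A minimal sketch of the two new caller patterns described in the message above, using an illustrative pattern and file names rather than anything taken from the diff:

from datasets.utils.py_utils import string_to_dict

# Illustrative pattern and paths, not constants from the library.
pattern = "data/{split}-{shard}.parquet"

# Before this commit a non-matching string raised ValueError, so call sites wrapped the call in try/except.
# Now a non-matching string returns None, so call sites branch on the return value instead:
fields = string_to_dict("README.md", pattern)
if fields is None:
    print("not a data file")

# When a match is known to exist (e.g. the name already passed an fnmatch check),
# callers assert the result is not None and index it directly:
fields = string_to_dict("data/train-00000.parquet", pattern)
assert fields is not None
print(fields["split"])  # train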

File tree (8 files changed: +98, -65 lines)

  src/datasets/arrow_dataset.py
  src/datasets/data_files.py
  src/datasets/dataset_dict.py
  src/datasets/features/audio.py
  src/datasets/features/image.py
  src/datasets/features/video.py
  src/datasets/utils/py_utils.py
  tests/test_py_utils.py


src/datasets/arrow_dataset.py (+14, -11)

@@ -3179,9 +3179,11 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                         del kwargs["shard"]
                 else:
                     logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
-            assert None not in transformed_shards, (
-                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
-            )
+            if None in transformed_shards:
+                raise ValueError(
+                    f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at "
+                    "least one worker failed to return its results"
+                )
             logger.info(f"Concatenating {num_proc} shards")
             result = _concatenate_map_style_datasets(transformed_shards)
             # update fingerprint if the dataset changed
@@ -5328,7 +5330,7 @@ def _push_parquet_shards_to_hub(
         max_shard_size: Optional[Union[int, str]] = None,
         num_shards: Optional[int] = None,
         embed_external_files: bool = True,
-    ) -> tuple[str, str, int, int, list[str], int]:
+    ) -> tuple[list[CommitOperationAdd], int, int]:
         """Pushes the dataset shards as Parquet files to the hub.

         Returns:
@@ -5374,7 +5376,7 @@ def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]:
         api = HfApi(endpoint=config.HF_ENDPOINT, token=token)

         uploaded_size = 0
-        additions = []
+        additions: list[CommitOperationAdd] = []
         for index, shard in hf_tqdm(
             enumerate(shards),
             desc="Uploading the dataset shards",
@@ -5559,8 +5561,9 @@ def push_to_hub(
         # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern)
         # and delete old split shards (if they exist)
         repo_with_dataset_card, repo_with_dataset_infos = False, False
-        deletions, deleted_size = [], 0
-        repo_splits = []  # use a list to keep the order of the splits
+        deletions: list[CommitOperationDelete] = []
+        deleted_size = 0
+        repo_splits: list[str] = []  # use a list to keep the order of the splits
         repo_files_to_add = [addition.path_in_repo for addition in additions]
         for repo_file in api.list_repo_tree(
             repo_id=repo_id, revision=revision, repo_type="dataset", token=token, recursive=True
@@ -5579,10 +5582,10 @@
             elif fnmatch.fnmatch(
                 repo_file.rfilename, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*")
             ):
-                repo_split = string_to_dict(
-                    repo_file.rfilename,
-                    glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED),
-                )["split"]
+                pattern = glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED)
+                split_pattern_fields = string_to_dict(repo_file.rfilename, pattern)
+                assert split_pattern_fields is not None
+                repo_split = split_pattern_fields["split"]
                 if repo_split not in repo_splits:
                     repo_splits.append(repo_split)

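The first hunk above also swaps a bare assert for an explicit check and ValueError, so the guard still runs when Python is started with -O and a failure surfaces as a regular exception rather than an AssertionError. A self-contained sketch of that guard, with a hypothetical list standing in for the per-worker map results:

# Hypothetical per-worker results: a None entry means that worker never returned its shard.
transformed_shards = ["shard-0", None, "shard-2"]
num_proc = len(transformed_shards)

if None in transformed_shards:
    raise ValueError(
        f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at "
        "least one worker failed to return its results"
    )
print(f"Concatenating {num_proc} shards")  # only reached once every worker has reported back
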
src/datasets/data_files.py (+7, -5)

@@ -264,14 +264,16 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]:
         except FileNotFoundError:
             continue
         if len(data_files) > 0:
-            splits: set[str] = {
-                string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))["split"]
-                for p in data_files
-            }
+            splits: set[str] = set()
+            for p in data_files:
+                p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))
+                assert p_parts is not None
+                splits.add(p_parts["split"])
+
             if any(not re.match(_split_re, split) for split in splits):
                 raise ValueError(f"Split name should match '{_split_re}'' but got '{splits}'.")
             sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
-                splits - set(DEFAULT_SPLITS)
+                splits - {str(split) for split in DEFAULT_SPLITS}
             )
             return {split: [split_pattern.format(split=split)] for split in sorted_splits}
     # then check the default patterns based on train/valid/test splits

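The old set comprehension left no room for a None check, so it becomes an explicit loop over the resolved files. A standalone sketch of the same loop, with a toy split pattern and file list in place of the resolver output, and os.path.basename standing in for xbasename:

import os

from datasets.utils.py_utils import glob_pattern_to_regex, string_to_dict

# Illustrative inputs, not the library's real split patterns.
split_pattern = "{split}.parquet"
data_files = ["data/train.parquet", "data/test.parquet"]

splits: set[str] = set()
for p in data_files:
    p_parts = string_to_dict(os.path.basename(p), glob_pattern_to_regex(os.path.basename(split_pattern)))
    assert p_parts is not None  # every resolved file already matched this pattern
    splits.add(p_parts["split"])

print(sorted(splits))  # ['test', 'train']
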
src/datasets/dataset_dict.py (+7, -7)

@@ -1765,8 +1765,8 @@ def push_to_hub(
         # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern)
         # and delete old split shards (if they exist)
         repo_with_dataset_card, repo_with_dataset_infos = False, False
-        repo_splits = []  # use a list to keep the order of the splits
-        deletions = []
+        repo_splits: list[str] = []  # use a list to keep the order of the splits
+        deletions: list[CommitOperationDelete] = []
         repo_files_to_add = [addition.path_in_repo for addition in additions]
         for repo_file in api.list_repo_tree(
             repo_id=repo_id,
@@ -1790,12 +1790,12 @@
                 repo_file.rfilename,
                 PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*"),
             ):
-                repo_split = string_to_dict(
-                    repo_file.rfilename,
-                    glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED),
-                )["split"]
+                pattern = glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED)
+                split_pattern_fields = string_to_dict(repo_file.rfilename, pattern)
+                assert split_pattern_fields is not None
+                repo_split = split_pattern_fields["split"]
                 if repo_split not in repo_splits:
-                    repo_splits.append(split)
+                    repo_splits.append(repo_split)

         # get the info from the README to update them
         if repo_with_dataset_card:

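The same try-free pattern lands in DatasetDict.push_to_hub, and the second hunk also fixes a wrong variable: the old code appended split instead of the repo_split just parsed from the filename. A standalone sketch of the extraction step, with a hypothetical filename and sharded-split glob standing in for repo_file.rfilename and the PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED constant:

from datasets.utils.py_utils import glob_pattern_to_regex, string_to_dict

# Illustrative stand-ins; the real values come from the repo file listing and a library constant.
rfilename = "data/validation-00003-of-00008.parquet"
sharded_split_pattern = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.parquet"

repo_splits: list[str] = []
pattern = glob_pattern_to_regex(sharded_split_pattern)
split_pattern_fields = string_to_dict(rfilename, pattern)
assert split_pattern_fields is not None  # in the real code the filename already passed an fnmatch check
repo_split = split_pattern_fields["split"]
if repo_split not in repo_splits:
    repo_splits.append(repo_split)
print(repo_splits)  # ['validation']
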
src/datasets/features/audio.py (+2, -5)

@@ -173,11 +173,8 @@ def decode_example(
             pattern = (
                 config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
             )
-            try:
-                repo_id = string_to_dict(source_url, pattern)["repo_id"]
-                token = token_per_repo_id[repo_id]
-            except (ValueError, KeyError):
-                token = None
+            source_url_fields = string_to_dict(source_url, pattern)
+            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None

             download_config = DownloadConfig(token=token)
             with xopen(path, "rb", download_config=download_config) as f:

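The try/except around the hub-URL parsing collapses into a single None check: if the source URL does not match the hub pattern (for example, because it is a local file path), there is simply no per-repo token to look up. A rough sketch with an illustrative pattern, URL and token mapping in place of config.HUB_DATASETS_URL and the real inputs:

from datasets.utils.py_utils import string_to_dict

# Illustrative values; the real code uses config.HUB_DATASETS_URL / config.HUB_DATASETS_HFFS_URL.
pattern = "https://huggingface.co/datasets/{repo_id}/resolve/{revision}/{path}"
source_url = "https://huggingface.co/datasets/username/my_dataset/resolve/main/data/audio_0.wav"
token_per_repo_id = {"username/my_dataset": "hf_xxx"}

source_url_fields = string_to_dict(source_url, pattern)
# None means the source is not a hub URL (e.g. a local file), so no token is needed.
token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
print(token)  # hf_xxx
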
src/datasets/features/image.py (+4, -5)

@@ -174,11 +174,10 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Image":
                 if source_url.startswith(config.HF_ENDPOINT)
                 else config.HUB_DATASETS_HFFS_URL
             )
-            try:
-                repo_id = string_to_dict(source_url, pattern)["repo_id"]
-                token = token_per_repo_id.get(repo_id)
-            except ValueError:
-                token = None
+            source_url_fields = string_to_dict(source_url, pattern)
+            token = (
+                token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
+            )
             download_config = DownloadConfig(token=token)
             with xopen(path, "rb", download_config=download_config) as f:
                 bytes_ = BytesIO(f.read())

src/datasets/features/video.py (+39, -26)

@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, Union

 import numpy as np
 import pyarrow as pa
@@ -18,6 +18,11 @@
     from .features import FeatureType


+class Example(TypedDict):
+    path: Optional[str]
+    bytes: Optional[bytes]
+
+
 @dataclass
 class Video:
     """
@@ -66,7 +71,7 @@ class Video:
     def __call__(self):
         return self.pa_type

-    def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict:
+    def encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoReader"]) -> Example:
         """Encode example into a format for Arrow.

         Args:
@@ -92,21 +97,29 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict:
         elif isinstance(value, np.ndarray):
             # convert the video array to bytes
             return encode_np_array(value)
-        elif VideoReader and isinstance(value, VideoReader):
+        elif VideoReader is not None and isinstance(value, VideoReader):
             # convert the torchvision video reader to bytes
             return encode_torchvision_video(value)
-        elif value.get("path") is not None and os.path.isfile(value["path"]):
-            # we set "bytes": None to not duplicate the data if they're already available locally
-            return {"bytes": None, "path": value.get("path")}
-        elif value.get("bytes") is not None or value.get("path") is not None:
-            # store the video bytes, and path is used to infer the video format using the file extension
-            return {"bytes": value.get("bytes"), "path": value.get("path")}
+        elif isinstance(value, dict):
+            path, bytes_ = value.get("path"), value.get("bytes")
+            if path is not None and os.path.isfile(path):
+                # we set "bytes": None to not duplicate the data if they're already available locally
+                return {"bytes": None, "path": path}
+            elif bytes_ is not None or path is not None:
+                # store the video bytes, and path is used to infer the video format using the file extension
+                return {"bytes": bytes_, "path": path}
+            else:
+                raise ValueError(
+                    f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
+                )
         else:
-            raise ValueError(
-                f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
-            )
+            raise TypeError(f"Unsupported encode_example type: {type(value)}")

-    def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
+    def decode_example(
+        self,
+        value: Union[str, Example],
+        token_per_repo_id: Optional[dict[str, Union[bool, str]]] = None,
+    ) -> "VideoReader":
         """Decode example video file into video data.

         Args:
@@ -136,15 +149,18 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
         if token_per_repo_id is None:
             token_per_repo_id = {}

-        path, bytes_ = value["path"], value["bytes"]
+        if isinstance(value, str):
+            path, bytes_ = value, None
+        else:
+            path, bytes_ = value["path"], value["bytes"]
+
         if bytes_ is None:
             if path is None:
                 raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.")
+            elif is_local_path(path):
+                video = VideoReader(path)
             else:
-                if is_local_path(path):
-                    video = VideoReader(path)
-                else:
-                    video = hf_video_reader(path, token_per_repo_id=token_per_repo_id)
+                video = hf_video_reader(path, token_per_repo_id=token_per_repo_id)
         else:
             video = VideoReader(bytes_)
         video._hf_encoded = {"path": path, "bytes": bytes_}
@@ -215,7 +231,7 @@ def video_to_bytes(video: "VideoReader") -> bytes:
     raise NotImplementedError()


-def encode_torchvision_video(video: "VideoReader") -> dict:
+def encode_torchvision_video(video: "VideoReader") -> Example:
     if hasattr(video, "_hf_encoded"):
         return video._hf_encoded
     else:
@@ -224,7 +240,7 @@ def encode_torchvision_video(video: "VideoReader") -> dict:
         )


-def encode_np_array(array: np.ndarray) -> dict:
+def encode_np_array(array: np.ndarray) -> Example:
     raise NotImplementedError()


@@ -235,7 +251,7 @@ def encode_np_array(array: np.ndarray) -> dict:


 def hf_video_reader(
-    path: str, token_per_repo_id: Optional[dict[str, str]] = None, stream: str = "video"
+    path: str, token_per_repo_id: Optional[dict[str, Union[bool, str]]] = None, stream: str = "video"
 ) -> "VideoReader":
     import av
     from torchvision import get_video_backend
@@ -246,11 +262,8 @@
         token_per_repo_id = {}
     source_url = path.split("::")[-1]
     pattern = config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
-    try:
-        repo_id = string_to_dict(source_url, pattern)["repo_id"]
-        token = token_per_repo_id.get(repo_id)
-    except ValueError:
-        token = None
+    source_url_fields = string_to_dict(source_url, pattern)
+    token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
     download_config = DownloadConfig(token=token)
     f = xopen(path, "rb", download_config=download_config)

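The new Example TypedDict just names the {"path", "bytes"} dicts the feature already passes around, and decode_example now also accepts a bare path string. A simplified sketch of the added input handling, with a hypothetical normalize helper standing in for the real method (no torchvision decoding involved):

from typing import Optional, TypedDict, Union


class Example(TypedDict):
    path: Optional[str]
    bytes: Optional[bytes]


def normalize(value: Union[str, Example]) -> Example:
    # Accept either a bare path string or an Example dict, as decode_example now does.
    if isinstance(value, str):
        path, bytes_ = value, None
    else:
        path, bytes_ = value["path"], value["bytes"]
    if bytes_ is None and path is None:
        raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.")
    return {"path": path, "bytes": bytes_}


print(normalize("folder/clip.mp4"))                 # {'path': 'folder/clip.mp4', 'bytes': None}
print(normalize({"path": None, "bytes": b"\x00"}))  # {'path': None, 'bytes': b'\x00'}
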
src/datasets/utils/py_utils.py (+5, -6)

@@ -159,7 +159,7 @@ def glob_pattern_to_regex(pattern):
     )


-def string_to_dict(string: str, pattern: str) -> dict[str, str]:
+def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]:
     """Un-format a string using a python f-string pattern.
     From https://stackoverflow.com/a/36838374

@@ -177,15 +177,14 @@ def string_to_dict(string: str, pattern: str) -> dict[str, str]:
         pattern (str): pattern formatted like a python f-string

     Returns:
-        Dict[str, str]: dictionary of variable -> value, retrieved from the input using the pattern
-
-    Raises:
-        ValueError: if the string doesn't match the pattern
+        Optional[dict[str, str]]: dictionary of variable -> value, retrieved from the input using the pattern, or
+            `None` if the string does not match the pattern.
     """
+    pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern)  # remove format specifiers, e.g. {rank:05d} -> {rank}
     regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern)
     result = re.search(regex, string)
     if result is None:
-        raise ValueError(f"String {string} doesn't match the pattern {pattern}")
+        return None
     values = list(result.groups())
     keys = re.findall(r"{(.+?)}", pattern)
     _dict = dict(zip(keys, values))

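Besides the Optional return type, the other functional change here is the pre-processing line that strips format specifiers before the pattern is turned into a regex, which is what lets patterns like the cache-file suffix template _{rank:05d}_of_{num_proc:05d} round-trip. A quick sketch, using an illustrative cache-file pattern rather than a constant from the library:

import re

from datasets.utils.py_utils import string_to_dict

pattern = "cache-{fingerprint}_{rank:05d}_of_{num_proc:05d}.arrow"  # illustrative

# The format specifier is dropped before the pattern becomes a regex, i.e. {rank:05d} is treated like {rank}:
print(re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern))
# cache-{fingerprint}_{rank}_of_{num_proc}.arrow

print(string_to_dict("cache-abc123_00001_of_00008.arrow", pattern))
# {'fingerprint': 'abc123', 'rank': '00001', 'num_proc': '00008'}

print(string_to_dict("cache-abc123.arrow", pattern))
# None: no match, where the old version raised ValueError
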
tests/test_py_utils.py (+20)

@@ -1,3 +1,4 @@
+import os
 import time
 from dataclasses import dataclass
 from multiprocessing import Pool
@@ -13,6 +14,7 @@
     asdict,
     iflatmap_unordered,
     map_nested,
+    string_to_dict,
     temp_seed,
     temporary_assignment,
     zip_dict,
@@ -267,3 +269,21 @@ def test_iflatmap_unordered():
     assert out.count("a") == 2
     assert out.count("b") == 2
     assert len(out) == 4
+
+
+def test_string_to_dict():
+    file_name = "dataset/cache-3b163736cf4505085d8b5f9b4c266c26.arrow"
+    file_name_prefix, file_name_ext = os.path.splitext(file_name)
+
+    suffix_template = "_{rank:05d}_of_{num_proc:05d}"
+    cache_file_name_pattern = file_name_prefix + suffix_template + file_name_ext
+
+    file_name_parts = string_to_dict(file_name, cache_file_name_pattern)
+    assert file_name_parts is None
+
+    rank = 1
+    num_proc = 2
+    file_name = file_name_prefix + suffix_template.format(rank=rank, num_proc=num_proc) + file_name_ext
+    file_name_parts = string_to_dict(file_name, cache_file_name_pattern)
+    assert file_name_parts is not None
+    assert file_name_parts == {"rank": f"{rank:05d}", "num_proc": f"{num_proc:05d}"}
