1
1
import os
2
2
from dataclasses import dataclass , field
3
- from typing import TYPE_CHECKING , Any , ClassVar , Optional , Union
3
+ from typing import TYPE_CHECKING , Any , ClassVar , Optional , TypedDict , Union
4
4
5
5
import numpy as np
6
6
import pyarrow as pa
18
18
from .features import FeatureType
19
19
20
20
21
+ class Example (TypedDict ):
22
+ path : Optional [str ]
23
+ bytes : Optional [bytes ]
24
+
25
+
21
26
@dataclass
22
27
class Video :
23
28
"""
@@ -66,7 +71,7 @@ class Video:
66
71
def __call__ (self ):
67
72
return self .pa_type
68
73
69
- def encode_example (self , value : Union [str , bytes , dict , np .ndarray , "VideoReader" ]) -> dict :
74
+ def encode_example (self , value : Union [str , bytes , Example , np .ndarray , "VideoReader" ]) -> Example :
70
75
"""Encode example into a format for Arrow.
71
76
72
77
Args:
@@ -92,21 +97,29 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader
92
97
elif isinstance (value , np .ndarray ):
93
98
# convert the video array to bytes
94
99
return encode_np_array (value )
95
- elif VideoReader and isinstance (value , VideoReader ):
100
+ elif VideoReader is not None and isinstance (value , VideoReader ):
96
101
# convert the torchvision video reader to bytes
97
102
return encode_torchvision_video (value )
98
- elif value .get ("path" ) is not None and os .path .isfile (value ["path" ]):
99
- # we set "bytes": None to not duplicate the data if they're already available locally
100
- return {"bytes" : None , "path" : value .get ("path" )}
101
- elif value .get ("bytes" ) is not None or value .get ("path" ) is not None :
102
- # store the video bytes, and path is used to infer the video format using the file extension
103
- return {"bytes" : value .get ("bytes" ), "path" : value .get ("path" )}
103
+ elif isinstance (value , dict ):
104
+ path , bytes_ = value .get ("path" ), value .get ("bytes" )
105
+ if path is not None and os .path .isfile (path ):
106
+ # we set "bytes": None to not duplicate the data if they're already available locally
107
+ return {"bytes" : None , "path" : path }
108
+ elif bytes_ is not None or path is not None :
109
+ # store the video bytes, and path is used to infer the video format using the file extension
110
+ return {"bytes" : bytes_ , "path" : path }
111
+ else :
112
+ raise ValueError (
113
+ f"A video sample should have one of 'path' or 'bytes' but they are missing or None in { value } ."
114
+ )
104
115
else :
105
- raise ValueError (
106
- f"A video sample should have one of 'path' or 'bytes' but they are missing or None in { value } ."
107
- )
116
+ raise TypeError (f"Unsupported encode_example type: { type (value )} " )
108
117
109
- def decode_example (self , value : dict , token_per_repo_id = None ) -> "VideoReader" :
118
+ def decode_example (
119
+ self ,
120
+ value : Union [str , Example ],
121
+ token_per_repo_id : Optional [dict [str , Union [bool , str ]]] = None ,
122
+ ) -> "VideoReader" :
110
123
"""Decode example video file into video data.
111
124
112
125
Args:
@@ -136,15 +149,18 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
136
149
if token_per_repo_id is None :
137
150
token_per_repo_id = {}
138
151
139
- path , bytes_ = value ["path" ], value ["bytes" ]
152
+ if isinstance (value , str ):
153
+ path , bytes_ = value , None
154
+ else :
155
+ path , bytes_ = value ["path" ], value ["bytes" ]
156
+
140
157
if bytes_ is None :
141
158
if path is None :
142
159
raise ValueError (f"A video should have one of 'path' or 'bytes' but both are None in { value } ." )
160
+ elif is_local_path (path ):
161
+ video = VideoReader (path )
143
162
else :
144
- if is_local_path (path ):
145
- video = VideoReader (path )
146
- else :
147
- video = hf_video_reader (path , token_per_repo_id = token_per_repo_id )
163
+ video = hf_video_reader (path , token_per_repo_id = token_per_repo_id )
148
164
else :
149
165
video = VideoReader (bytes_ )
150
166
video ._hf_encoded = {"path" : path , "bytes" : bytes_ }
@@ -215,7 +231,7 @@ def video_to_bytes(video: "VideoReader") -> bytes:
215
231
raise NotImplementedError ()
216
232
217
233
218
- def encode_torchvision_video (video : "VideoReader" ) -> dict :
234
+ def encode_torchvision_video (video : "VideoReader" ) -> Example :
219
235
if hasattr (video , "_hf_encoded" ):
220
236
return video ._hf_encoded
221
237
else :
@@ -224,7 +240,7 @@ def encode_torchvision_video(video: "VideoReader") -> dict:
224
240
)
225
241
226
242
227
- def encode_np_array (array : np .ndarray ) -> dict :
243
+ def encode_np_array (array : np .ndarray ) -> Example :
228
244
raise NotImplementedError ()
229
245
230
246
@@ -235,7 +251,7 @@ def encode_np_array(array: np.ndarray) -> dict:
235
251
236
252
237
253
def hf_video_reader (
238
- path : str , token_per_repo_id : Optional [dict [str , str ]] = None , stream : str = "video"
254
+ path : str , token_per_repo_id : Optional [dict [str , Union [ bool , str ] ]] = None , stream : str = "video"
239
255
) -> "VideoReader" :
240
256
import av
241
257
from torchvision import get_video_backend
@@ -246,11 +262,8 @@ def hf_video_reader(
246
262
token_per_repo_id = {}
247
263
source_url = path .split ("::" )[- 1 ]
248
264
pattern = config .HUB_DATASETS_URL if source_url .startswith (config .HF_ENDPOINT ) else config .HUB_DATASETS_HFFS_URL
249
- try :
250
- repo_id = string_to_dict (source_url , pattern )["repo_id" ]
251
- token = token_per_repo_id .get (repo_id )
252
- except ValueError :
253
- token = None
265
+ source_url_fields = string_to_dict (source_url , pattern )
266
+ token = token_per_repo_id .get (source_url_fields ["repo_id" ]) if source_url_fields is not None else None
254
267
download_config = DownloadConfig (token = token )
255
268
f = xopen (path , "rb" , download_config = download_config )
256
269
0 commit comments