8
8
import torch
9
9
import codecs
10
10
import string
11
+ from urllib .error import URLError
11
12
from .utils import download_url , download_and_extract_archive , extract_archive , \
12
13
makedir_exist_ok , verify_str_arg
13
14
@@ -29,11 +30,16 @@ class MNIST(VisionDataset):
29
30
target and transforms it.
30
31
"""
31
32
33
+ mirrors = [
34
+ 'http://yann.lecun.com/exdb/mnist/' ,
35
+ 'https://ossci-datasets.s3.amazonaws.com/mnist/' ,
36
+ ]
37
+
32
38
resources = [
33
- ("http://yann.lecun.com/exdb/mnist/ train-images-idx3-ubyte.gz" , "f68b3c2dcbeaaa9fbdd348bbdeb94873" ),
34
- ("http://yann.lecun.com/exdb/mnist/ train-labels-idx1-ubyte.gz" , "d53e105ee54ea40749a09fcbcd1e9432" ),
35
- ("http://yann.lecun.com/exdb/mnist/ t10k-images-idx3-ubyte.gz" , "9fb629c4189551a2d022fa330f9573f3" ),
36
- ("http://yann.lecun.com/exdb/mnist/ t10k-labels-idx1-ubyte.gz" , "ec29112dd5afa0611ce80d1b7f02629c" )
39
+ ("train-images-idx3-ubyte.gz" , "f68b3c2dcbeaaa9fbdd348bbdeb94873" ),
40
+ ("train-labels-idx1-ubyte.gz" , "d53e105ee54ea40749a09fcbcd1e9432" ),
41
+ ("t10k-images-idx3-ubyte.gz" , "9fb629c4189551a2d022fa330f9573f3" ),
42
+ ("t10k-labels-idx1-ubyte.gz" , "ec29112dd5afa0611ce80d1b7f02629c" )
37
43
]
38
44
39
45
training_file = 'training.pt'
@@ -133,9 +139,26 @@ def download(self):
133
139
makedir_exist_ok (self .processed_folder )
134
140
135
141
# download files
136
- for url , md5 in self .resources :
137
- filename = url .rpartition ('/' )[2 ]
138
- download_and_extract_archive (url , download_root = self .raw_folder , filename = filename , md5 = md5 )
142
+ for filename , md5 in self .resources :
143
+ for mirror in self .mirrors :
144
+ url = "{}{}" .format (mirror , filename )
145
+ try :
146
+ print ("Downloading {}" .format (url ))
147
+ download_and_extract_archive (
148
+ url , download_root = self .raw_folder ,
149
+ filename = filename ,
150
+ md5 = md5
151
+ )
152
+ except URLError as error :
153
+ print (
154
+ "Failed to download (trying next):\n {}" .format (error )
155
+ )
156
+ continue
157
+ finally :
158
+ print ()
159
+ break
160
+ else :
161
+ raise RuntimeError ("Error downloading {}" .format (filename ))
139
162
140
163
# process and save as torch files
141
164
print ('Processing...' )
0 commit comments