Skip to content

Commit 0dbc7d8

Browse files
pmeier, seemethere, and fmassa
authored
Fix MNIST download for minor release (#3559)
* datasets: Fall back to our own mirrors for MNIST (#3544)
  We are experiencing 403s when trying to download from the main MNIST site, so let's fall back to our own mirror on failure.
  Signed-off-by: Eli Uriegas <[email protected]>
  Co-authored-by: Francisco Massa <[email protected]>
* Fix (Fashion|K)MNIST download and MNIST download test (#3557)
  * add mirrors to (Fashion|K)MNIST
  * fix download tests for MNIST
  Co-authored-by: Eli Uriegas <[email protected]>
  Co-authored-by: Francisco Massa <[email protected]>
1 parent 01dfa8e commit 0dbc7d8

File tree

2 files changed

+48
-20
lines changed

2 files changed

+48
-20
lines changed

test/test_datasets_download.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,8 @@ def voc():
249249

250250

251251
def mnist():
252-
return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
252+
with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
253+
return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
253254

254255

255256
def fashion_mnist():

torchvision/datasets/mnist.py

+46 -19
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import gzip
1111
import lzma
1212
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
13+
from urllib.error import URLError
1314
from .utils import download_url, download_and_extract_archive, extract_archive, \
1415
verify_str_arg
1516

@@ -31,11 +32,16 @@ class MNIST(VisionDataset):
3132
target and transforms it.
3233
"""
3334

35+
mirrors = [
36+
'http://yann.lecun.com/exdb/mnist/',
37+
'https://ossci-datasets.s3.amazonaws.com/mnist/',
38+
]
39+
3440
resources = [
35-
("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
36-
("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
37-
("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
38-
("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
41+
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
42+
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
43+
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
44+
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
3945
]
4046

4147
training_file = 'training.pt'
@@ -141,9 +147,26 @@ def download(self) -> None:
141147
os.makedirs(self.processed_folder, exist_ok=True)
142148

143149
# download files
144-
for url, md5 in self.resources:
145-
filename = url.rpartition('/')[2]
146-
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
150+
for filename, md5 in self.resources:
151+
for mirror in self.mirrors:
152+
url = "{}{}".format(mirror, filename)
153+
try:
154+
print("Downloading {}".format(url))
155+
download_and_extract_archive(
156+
url, download_root=self.raw_folder,
157+
filename=filename,
158+
md5=md5
159+
)
160+
except URLError as error:
161+
print(
162+
"Failed to download (trying next):\n{}".format(error)
163+
)
164+
continue
165+
finally:
166+
print()
167+
break
168+
else:
169+
raise RuntimeError("Error downloading {}".format(filename))
147170

148171
# process and save as torch files
149172
print('Processing...')
@@ -183,15 +206,15 @@ class FashionMNIST(MNIST):
183206
target_transform (callable, optional): A function/transform that takes in the
184207
target and transforms it.
185208
"""
209+
mirrors = [
210+
"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
211+
]
212+
186213
resources = [
187-
("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
188-
"8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
189-
("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
190-
"25c81989df183df01b3e8a0aad5dffbe"),
191-
("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
192-
"bef4ecab320f06d8554ea6380940ec79"),
193-
("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
194-
"bb300cfdad3c16e7a12a480ee83cd310")
214+
("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
215+
("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
216+
("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
217+
("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310")
195218
]
196219
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
197220
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
@@ -213,11 +236,15 @@ class KMNIST(MNIST):
213236
target_transform (callable, optional): A function/transform that takes in the
214237
target and transforms it.
215238
"""
239+
mirrors = [
240+
"http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
241+
]
242+
216243
resources = [
217-
("http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
218-
("http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
219-
("http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
220-
("http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134")
244+
("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
245+
("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
246+
("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
247+
("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134")
221248
]
222249
classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
223250

0 commit comments

Comments
 (0)