Skip to content

Commit b9626de

Browse files
committed
fix mnist
1 parent 686aa34 commit b9626de

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tests/base/datasets.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class MNIST(Dataset):
4949
cache_folder_name = 'complete'
5050

5151
def __init__(self, root: str = PATH_DATASETS, train: bool = True,
52-
normalize: tuple = (0.5, 1.0), download: bool = False):
52+
normalize: tuple = (0.5, 1.0), download: bool = True):
5353
super().__init__()
5454
self.root = root
5555
self.train = train # training set or test set
@@ -86,8 +86,7 @@ def _check_exists(self, data_folder: str) -> bool:
8686
return existing
8787

8888
def prepare_data(self, download: bool):
89-
if download:
90-
self._download(self.cached_folder_path)
89+
self._download(self.cached_folder_path)
9190

9291
def _download(self, data_folder: str) -> None:
9392
"""Download the MNIST data if it doesn't exist in cached_folder_path already."""
@@ -175,8 +174,7 @@ def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor,
175174
def prepare_data(self, download: bool) -> None:
176175
if self._check_exists(self.cached_folder_path):
177176
return
178-
if download:
179-
self._download(super().cached_folder_path)
177+
self._download(super().cached_folder_path)
180178

181179
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
182180
data, targets = torch.load(os.path.join(super().cached_folder_path, fname))

0 commit comments

Comments
 (0)