Skip to content

Commit 6bdaf36

Browse files
committed
fix mnist
1 parent 89224e5 commit 6bdaf36

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/base/datasets.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ class MNIST(Dataset):
4949
cache_folder_name = 'complete'
5050

5151
def __init__(self, root: str = PATH_DATASETS, train: bool = True,
52-
normalize: tuple = (0.5, 1.0), download: bool = False):
52+
normalize: tuple = (0.5, 1.0), download: bool = True):
5353
super().__init__()
5454
self.root = root
5555
self.train = train # training set or test set
5656
self.normalize = normalize
5757

58-
self.prepare_data(download)
58+
self.prepare_data()
5959

6060
if not self._check_exists(self.cached_folder_path):
6161
raise RuntimeError('Dataset not found.')
@@ -85,7 +85,7 @@ def _check_exists(self, data_folder: str) -> bool:
8585
existing = existing and os.path.isfile(os.path.join(data_folder, fname))
8686
return existing
8787

88-
def prepare_data(self):
88+
def prepare_data(self, download: bool):
8989
self._download(self.cached_folder_path)
9090

9191
def _download(self, data_folder: str) -> None:
@@ -171,7 +171,7 @@ def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor,
171171
targets = full_targets[indexes]
172172
return data, targets
173173

174-
def prepare_data(self) -> None:
174+
def prepare_data(self, download: bool) -> None:
175175
if self._check_exists(self.cached_folder_path):
176176
return
177177
self._download(super().cached_folder_path)

0 commit comments

Comments
 (0)