@@ -49,13 +49,13 @@ class MNIST(Dataset):
49
49
cache_folder_name = 'complete'
50
50
51
51
def __init__ (self , root : str = PATH_DATASETS , train : bool = True ,
52
- normalize : tuple = (0.5 , 1.0 ), download : bool = False ):
52
+ normalize : tuple = (0.5 , 1.0 ), download : bool = True ):
53
53
super ().__init__ ()
54
54
self .root = root
55
55
self .train = train # training set or test set
56
56
self .normalize = normalize
57
57
58
- self .prepare_data (download )
58
+ self .prepare_data ()
59
59
60
60
if not self ._check_exists (self .cached_folder_path ):
61
61
raise RuntimeError ('Dataset not found.' )
@@ -85,7 +85,7 @@ def _check_exists(self, data_folder: str) -> bool:
85
85
existing = existing and os .path .isfile (os .path .join (data_folder , fname ))
86
86
return existing
87
87
88
- def prepare_data (self ):
88
+ def prepare_data (self , download : bool ):
89
89
self ._download (self .cached_folder_path )
90
90
91
91
def _download (self , data_folder : str ) -> None :
@@ -171,7 +171,7 @@ def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor,
171
171
targets = full_targets [indexes ]
172
172
return data , targets
173
173
174
- def prepare_data (self ) -> None :
174
+ def prepare_data (self , download : bool ) -> None :
175
175
if self ._check_exists (self .cached_folder_path ):
176
176
return
177
177
self ._download (super ().cached_folder_path )
0 commit comments