Skip to content

Commit 09e5a87

Browse files
Adrian WälchliBorda
authored and
akarnachev
committed
Add MNIST dataset & drop torchvision dep. from tests (Lightning-AI#986)
* added custom mnist without torchvision dep * move files so it does not conflict with mnist gitignore * mock torchvision for tests * fix line too long * fix line too long * fix "module level import not at top of file" warning * move mock imports to __init__.py * simplify MNIST a lot and download directly the .pt files * further simplify and clean up mnist * revert import overrides * make as before * drop PIL requirement * move mnist.py to datasets subfolder * use logging instead of print * choose same name as in torchvision * remove torchvision and pillow also from yml file * refactor if train Co-Authored-By: Jirka Borovec <[email protected]> * capitalized class attr * moved mnist to models * re-added datsets ignore * better name for file variable * Update mnist.py * move dataset classes to datasets.py * new line * update * update * fix automerge * move to base folder * adapt testingmnist to new mnist base class * remove temporal fix * fix datatype * remove old testingmnist * readable * fix import * fix whitespace * docstring Co-Authored-By: Jirka Borovec <[email protected]> * Update tests/base/datasets.py Co-Authored-By: Jirka Borovec <[email protected]> * changelog * added types * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <[email protected]> * exist->isfile Co-Authored-By: Jirka Borovec <[email protected]> * index -> idx * temporary fix for trains error * better changelog message Co-authored-by: Jirka Borovec <[email protected]>
1 parent 3d65e2e commit 09e5a87

8 files changed

+115
-40
lines changed

.gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ docs/source/pl_examples*.rst
1717
docs/source/pytorch_lightning*.rst
1818
docs/source/tests*.rst
1919
docs/source/*.md
20-
tests/tests/
2120

2221
# Byte-compiled / optimized / DLL files
2322
__pycache__/

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3131
### Removed
3232

3333
- Removed duplicated module `pytorch_lightning.utilities.arg_parse` for loading CLI arguments ([#1167](https://github.com/PyTorchLightning/pytorch-lightning/issues/1167))
34+
- Dropped `torchvision` dependency in tests and added own MNIST dataset class instead ([#986](https://github.com/PyTorchLightning/pytorch-lightning/issues/986))
3435

3536
### Fixed
3637

environment.yml

-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ dependencies:
1515
- future>=0.17.1
1616

1717
# For dev and testing
18-
- torchvision>=0.4.0
1918
- tox
2019
- coverage
2120
- codecov
@@ -26,7 +25,6 @@ dependencies:
2625
- autopep8
2726
- check-manifest
2827
- twine==1.13.0
29-
- pillow<7.0.0
3028

3129
- pip:
3230
- test-tube>=0.7.5

tests/base/datasets.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import logging
2+
import os
3+
import urllib.request
4+
from typing import Tuple
5+
6+
import torch
7+
from torch import Tensor
8+
from torch.utils.data import Dataset
9+
10+
11+
class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing PyTorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root: Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train: If ``True``, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize: ``(mean, std)`` pair applied to every image returned by
            ``__getitem__``. Pass ``None`` to skip normalization.
        download: If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If ``training.pt`` or ``test.pt`` is missing from the
            processed folder (after the optional download).
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'

    def __init__(self, root: str, train: bool = True, normalize: tuple = (0.5, 1.0), download: bool = False):
        super(MNIST, self).__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        # NOTE: ``torch.load`` unpickles the file, so ``root`` / ``RESOURCES``
        # must only ever point at trusted data.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
        """Return the ``(image, target)`` pair stored at position ``idx``.

        The image is cast to float and gets one extra leading dimension; it is
        normalized only when ``self.normalize`` is not ``None``.
        """
        img = self.data[idx].float().unsqueeze(0)
        target = int(self.targets[idx])

        if self.normalize is not None:
            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])

        return img, target

    def __len__(self) -> int:
        """Return the number of samples currently held in ``self.data``."""
        return len(self.data)

    @property
    def processed_folder(self) -> str:
        """Directory (``<root>/MNIST/processed``) that holds the ``.pt`` files."""
        return os.path.join(self.root, 'MNIST', 'processed')

    def _check_exists(self) -> bool:
        """Return ``True`` only if both the training and the test file are present."""
        train_file = os.path.join(self.processed_folder, self.TRAIN_FILE_NAME)
        test_file = os.path.join(self.processed_folder, self.TEST_FILE_NAME)
        return os.path.isfile(train_file) and os.path.isfile(test_file)

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Every resource is first written to a temporary ``.part`` file and moved
        to its final name only after the transfer completed, so an interrupted
        download never leaves a truncated file that ``_check_exists`` would
        accept as a valid dataset.
        """
        if self._check_exists():
            return

        os.makedirs(self.processed_folder, exist_ok=True)

        for url in self.RESOURCES:
            logging.info(f'Downloading {url}')
            fpath = os.path.join(self.processed_folder, os.path.basename(url))
            # the pid keeps concurrently downloading processes from sharing a temp file
            tmp_path = f'{fpath}.{os.getpid()}.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, fpath)
            finally:
                # only still present when the transfer or the move failed
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
86+
87+
88+
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
    """Return a copy of ``tensor`` shifted by ``mean`` and scaled by ``std``.

    The input is left untouched: the arithmetic runs in place on a clone, with
    ``mean`` and ``std`` converted to the clone's dtype and device first.
    """
    like_input = dict(dtype=tensor.dtype, device=tensor.device)
    shift = torch.as_tensor(mean, **like_input)
    scale = torch.as_tensor(std, **like_input)

    result = tensor.clone()
    result.sub_(shift)
    result.div_(scale)
    return result
94+
95+
96+
class TestingMNIST(MNIST):
    """MNIST variant that keeps only the first ``num_samples`` entries.

    All other arguments are forwarded unchanged to the parent constructor.
    """

    def __init__(self, root, train=True, normalize=(0.5, 1.0), download=False, num_samples=8000):
        super().__init__(root, train=train, normalize=normalize, download=download)
        # take just a subset of MNIST dataset
        head = slice(None, num_samples)
        self.data, self.targets = self.data[head], self.targets[head]

tests/base/debug.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
22
from torch.nn import functional as F
33
from torch.utils.data import DataLoader
4-
from torchvision.datasets import MNIST
54

65
import pytorch_lightning as pl
6+
from tests.base.datasets import MNIST
77

88

99
# from test_models import assert_ok_test_acc, load_model, \

tests/base/models.py

+5-33
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import torch.nn.functional as F
88
from torch import optim
99
from torch.utils.data import DataLoader
10-
from torchvision import transforms
11-
from torchvision.datasets import MNIST
10+
11+
from tests.base.datasets import TestingMNIST
1212

1313
try:
1414
from test_tube import HyperOptArgumentParser
@@ -18,29 +18,6 @@
1818

1919
from pytorch_lightning.core.lightning import LightningModule
2020

21-
# TODO: remove after getting own MNIST
22-
# TEMPORAL FIX, https://github.com/pytorch/vision/issues/1938
23-
import urllib.request
24-
opener = urllib.request.build_opener()
25-
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
26-
urllib.request.install_opener(opener)
27-
28-
29-
class TestingMNIST(MNIST):
30-
31-
def __init__(self, root, train=True, transform=None, target_transform=None,
32-
download=False, num_samples=8000):
33-
super().__init__(
34-
root,
35-
train=train,
36-
transform=transform,
37-
target_transform=target_transform,
38-
download=download
39-
)
40-
# take just a subset of MNIST dataset
41-
self.data = self.data[:num_samples]
42-
self.targets = self.targets[:num_samples]
43-
4421

4522
class DictHparamsModel(LightningModule):
4623

@@ -61,8 +38,7 @@ def configure_optimizers(self):
6138
return torch.optim.Adam(self.parameters(), lr=0.02)
6239

6340
def train_dataloader(self):
64-
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True,
65-
transform=transforms.ToTensor()), batch_size=32)
41+
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True), batch_size=32)
6642

6743

6844
class TestModelBase(LightningModule):
@@ -178,17 +154,13 @@ def configure_optimizers(self):
178154
return [optimizer], [scheduler]
179155

180156
def prepare_data(self):
181-
transform = transforms.Compose([transforms.ToTensor(),
182-
transforms.Normalize((0.5,), (1.0,))])
183157
_ = TestingMNIST(root=self.hparams.data_root, train=True,
184-
transform=transform, download=True, num_samples=2000)
158+
download=True, num_samples=2000)
185159

186160
def _dataloader(self, train):
187161
# init data generators
188-
transform = transforms.Compose([transforms.ToTensor(),
189-
transforms.Normalize((0.5,), (1.0,))])
190162
dataset = TestingMNIST(root=self.hparams.data_root, train=train,
191-
transform=transform, download=False, num_samples=2000)
163+
download=False, num_samples=2000)
192164

193165
# when using multi-node we need to add the datasampler
194166
batch_size = self.hparams.batch_size

tests/datasets/__init__.py

Whitespace-only changes.

tests/requirements.txt

+1-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
-r ../requirements-extra.txt
33

44
# extended list of dependencies for development, linting and tests
5-
torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT
65
tox
76
coverage
87
codecov
@@ -11,5 +10,4 @@ pytest-cov
1110
pytest-flake8
1211
flake8
1312
check-manifest
14-
twine==1.13.0
15-
pillow<7.0.0
13+
twine==1.13.0

0 commit comments

Comments
 (0)