
Commit d85aa6d

Added LFW Dataset (#4255)

Authored by ABD-01, with co-authors pmeier and fmassa.

* Added LFW Dataset
* Added dataset to list in __init__.py
* Updated lfw.py
* Created a common superclass for people and pairs type datasets
* Corrected the .download() method
* Added docstrings and updated datasets.rst
* Wrote tests for LFWPeople and LFWPairs
* Resolved mypy error: Need type annotation for "data"
* Updated inject_fake_data method for LFWPeople
* Updated tests for LFW
* Updated LFW tests and minor changes in lfw.py
* Updated LFW
* Added functionality for 10-fold validation view
* Optimized the code to replace repeated lines with a method in the superclass
* Updated LFWPeople to get classes from all of lfw-names.txt rather than just the classes from the train set
* Updated lfw.py and tests
* Updated inject_fake_data method to create 10fold fake data
* Minor changes in docstring and extra_repr
* Resolved pylint errors
* Added checksums for annotation files
* Minor changes in test
* Updated docstrings, defaults and minor changes in test
* Removed 'os.path.exists' check

Co-authored-by: ABD-01 <[email protected]>
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Francisco Massa <[email protected]>

1 parent 4d92892 · commit d85aa6d

File tree: 4 files changed, +352 −1 lines changed


docs/source/datasets.rst (+11)

@@ -147,6 +147,17 @@ KMNIST
 
 .. autoclass:: KMNIST
 
+LFW
+~~~~~
+
+.. autoclass:: LFWPeople
+  :members: __getitem__
+  :special-members:
+
+.. autoclass:: LFWPairs
+  :members: __getitem__
+  :special-members:
+
 LSUN
 ~~~~
 
test/test_datasets.py (+82)

@@ -1801,5 +1801,87 @@ def test_targets(self):
             assert item[6] == i // 3
 
 
+class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
+    DATASET_CLASS = datasets.LFWPeople
+    FEATURE_TYPES = (PIL.Image.Image, int)
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=('10fold', 'train', 'test'),
+        image_set=('original', 'funneled', 'deepfunneled')
+    )
+
+    _IMAGES_DIR = {
+        "original": "lfw",
+        "funneled": "lfw_funneled",
+        "deepfunneled": "lfw-deepfunneled"
+    }
+    _file_id = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
+
+    def inject_fake_data(self, tmpdir, config):
+        tmpdir = pathlib.Path(tmpdir) / "lfw-py"
+        os.makedirs(tmpdir, exist_ok=True)
+        return dict(
+            num_examples=self._create_images_dir(tmpdir, self._IMAGES_DIR[config["image_set"]], config["split"]),
+            split=config["split"]
+        )
+
+    def _create_images_dir(self, root, idir, split):
+        idir = os.path.join(root, idir)
+        os.makedirs(idir, exist_ok=True)
+        n, flines = (10, ["10\n"]) if split == "10fold" else (1, [])
+        num_examples = 0
+        names = []
+        for _ in range(n):
+            num_people = random.randint(2, 5)
+            flines.append(f"{num_people}\n")
+            for i in range(num_people):
+                name = self._create_random_id()
+                no = random.randint(1, 10)
+                flines.append(f"{name}\t{no}\n")
+                names.append(f"{name}\t{no}\n")
+                datasets_utils.create_image_folder(idir, name, lambda n: f"{name}_{n+1:04d}.jpg", no, 250)
+                num_examples += no
+        with open(pathlib.Path(root) / f"people{self._file_id[split]}.txt", "w") as f:
+            f.writelines(flines)
+        with open(pathlib.Path(root) / "lfw-names.txt", "w") as f:
+            f.writelines(sorted(names))
+
+        return num_examples
+
+    def _create_random_id(self):
+        part1 = datasets_utils.create_random_string(random.randint(5, 7))
+        part2 = datasets_utils.create_random_string(random.randint(4, 7))
+        return f"{part1}_{part2}"
+
+
+class LFWPairsTestCase(LFWPeopleTestCase):
+    DATASET_CLASS = datasets.LFWPairs
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, int)
+
+    def _create_images_dir(self, root, idir, split):
+        idir = os.path.join(root, idir)
+        os.makedirs(idir, exist_ok=True)
+        num_pairs = 7  # effectively 7*2*n = 14*n
+        n, self.flines = (10, [f"10\t{num_pairs}"]) if split == "10fold" else (1, [str(num_pairs)])
+        for _ in range(n):
+            self._inject_pairs(idir, num_pairs, True)
+            self._inject_pairs(idir, num_pairs, False)
+        with open(pathlib.Path(root) / f"pairs{self._file_id[split]}.txt", "w") as f:
+            f.writelines(self.flines)
+
+        return num_pairs * 2 * n
+
+    def _inject_pairs(self, root, num_pairs, same):
+        for i in range(num_pairs):
+            name1 = self._create_random_id()
+            name2 = name1 if same else self._create_random_id()
+            no1, no2 = random.randint(1, 100), random.randint(1, 100)
+            if same:
+                self.flines.append(f"\n{name1}\t{no1}\t{no2}")
+            else:
+                self.flines.append(f"\n{name1}\t{no1}\t{name2}\t{no2}")
+
+            datasets_utils.create_image_folder(root, name1, lambda _: f"{name1}_{no1:04d}.jpg", 1, 250)
+            datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250)
+
+
 if __name__ == "__main__":
     unittest.main()
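For reference, the two new test cases can be exercised in isolation with a plain unittest runner. A minimal sketch, assuming it is executed from torchvision's test/ directory so that test_datasets and datasets_utils are importable; the runner invocation itself is not part of this commit:

import unittest

from test_datasets import LFWPairsTestCase, LFWPeopleTestCase

# Collect only the LFW-related cases added by this commit and run them.
loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(LFWPeopleTestCase))
suite.addTests(loader.loadTestsFromTestCase(LFWPairsTestCase))
unittest.TextTestRunner(verbosity=2).run(suite)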

torchvision/datasets/__init__.py (+2 −1)

@@ -26,6 +26,7 @@
 from .places365 import Places365
 from .kitti import Kitti
 from .inaturalist import INaturalist
+from .lfw import LFWPeople, LFWPairs
 
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -36,5 +37,5 @@
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
            'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
            'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101',
-           'Places365', 'Kitti', "INaturalist"
+           'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs"
            )
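With the import and the two new __all__ entries in place, both classes are reachable from the package namespace. A quick smoke check, illustrative rather than part of the diff:

import torchvision.datasets as datasets

# Both names are now exported at the package level alongside the other datasets.
assert "LFWPeople" in datasets.__all__ and "LFWPairs" in datasets.__all__
print(datasets.LFWPeople, datasets.LFWPairs)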

torchvision/datasets/lfw.py (new file, +257)

import os
from typing import Any, Callable, List, Optional, Tuple
from PIL import Image
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg


class _LFW(VisionDataset):

    base_folder = 'lfw-py'
    download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"

    file_dict = {
        'original': ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
        'funneled': ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
        'deepfunneled': ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201")
    }
    checksums = {
        'pairs.txt': '9f1ba174e4e1c508ff7cdf10ac338a7d',
        'pairsDevTest.txt': '5132f7440eb68cf58910c8a45a2ac10b',
        'pairsDevTrain.txt': '4f27cbf15b2da4a85c1907eb4181ad21',
        'people.txt': '450f0863dd89e85e73936a6d71a3474b',
        'peopleDevTest.txt': 'e4bf5be0a43b5dcd9dc5ccfcb8fb19c5',
        'peopleDevTrain.txt': '54eaac34beb6d042ed3a7d883e247a21',
        'lfw-names.txt': 'a6d0a479bd074669f656265a6e693f6d'
    }
    annot_file = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
    names = "lfw-names.txt"

    def __init__(
            self,
            root: str,
            split: str,
            image_set: str,
            view: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ):
        super(_LFW, self).__init__(os.path.join(root, self.base_folder),
                                   transform=transform, target_transform=target_transform)

        self.image_set = verify_str_arg(image_set.lower(), 'image_set', self.file_dict.keys())
        images_dir, self.filename, self.md5 = self.file_dict[self.image_set]

        self.view = verify_str_arg(view.lower(), 'view', ['people', 'pairs'])
        self.split = verify_str_arg(split.lower(), 'split', ['10fold', 'train', 'test'])
        self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
        self.data: List[Any] = []

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.images_dir = os.path.join(self.root, images_dir)

    def _loader(self, path: str) -> Image.Image:
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def _check_integrity(self):
        st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
        st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
        if not st1 or not st2:
            return False
        if self.view == "people":
            return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        url = f"{self.download_url_prefix}{self.filename}"
        download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
        download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
        if self.view == "people":
            download_url(f"{self.download_url_prefix}{self.names}", self.root)

    def _get_path(self, identity, no):
        return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")

    def extra_repr(self) -> str:
        return f"Alignment: {self.image_set}\nSplit: {self.split}"

    def __len__(self):
        return len(self.data)


class LFWPeople(_LFW):
    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``lfw-py`` exists or will be saved to if download is set to True.
        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
            ``10fold`` (default).
        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
            ``deepfunneled``. Defaults to ``funneled``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomRotation``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    def __init__(
            self,
            root: str,
            split: str = "10fold",
            image_set: str = "funneled",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ):
        super(LFWPeople, self).__init__(root, split, image_set, "people",
                                        transform, target_transform, download)

        self.class_to_idx = self._get_classes()
        self.data, self.targets = self._get_people()

    def _get_people(self):
        data, targets = [], []
        with open(os.path.join(self.root, self.labels_file), 'r') as f:
            lines = f.readlines()
            n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)

            for fold in range(n_folds):
                n_lines = int(lines[s])
                people = [line.strip().split("\t") for line in lines[s + 1: s + n_lines + 1]]
                s += n_lines + 1
                for i, (identity, num_imgs) in enumerate(people):
                    for num in range(1, int(num_imgs) + 1):
                        img = self._get_path(identity, num)
                        data.append(img)
                        targets.append(self.class_to_idx[identity])

        return data, targets

    def _get_classes(self):
        with open(os.path.join(self.root, self.names), 'r') as f:
            lines = f.readlines()
            names = [line.strip().split()[0] for line in lines]
        class_to_idx = {name: i for i, name in enumerate(names)}
        return class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the identity of the person.
        """
        img = self._loader(self.data[index])
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def extra_repr(self) -> str:
        return super().extra_repr() + "\nClasses (identities): {}".format(len(self.class_to_idx))


class LFWPairs(_LFW):
    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``lfw-py`` exists or will be saved to if download is set to True.
        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
            ``10fold``. Defaults to ``10fold``.
        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
            ``deepfunneled``. Defaults to ``funneled``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomRotation``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    def __init__(
            self,
            root: str,
            split: str = "10fold",
            image_set: str = "funneled",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ):
        super(LFWPairs, self).__init__(root, split, image_set, "pairs",
                                       transform, target_transform, download)

        self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)

    def _get_pairs(self, images_dir):
        pair_names, data, targets = [], [], []
        with open(os.path.join(self.root, self.labels_file), 'r') as f:
            lines = f.readlines()
            if self.split == "10fold":
                n_folds, n_pairs = lines[0].split("\t")
                n_folds, n_pairs = int(n_folds), int(n_pairs)
            else:
                n_folds, n_pairs = 1, int(lines[0])
            s = 1

            for fold in range(n_folds):
                matched_pairs = [line.strip().split("\t") for line in lines[s: s + n_pairs]]
                unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs: s + (2 * n_pairs)]]
                s += (2 * n_pairs)
                for pair in matched_pairs:
                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
                    pair_names.append((pair[0], pair[0]))
                    data.append((img1, img2))
                    targets.append(same)
                for pair in unmatched_pairs:
                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
                    pair_names.append((pair[0], pair[2]))
                    data.append((img1, img2))
                    targets.append(same)

        return pair_names, data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image1, image2, target) where target is `0` for different identities and `1` for same identities.
        """
        img1, img2 = self.data[index]
        img1, img2 = self._loader(img1), self._loader(img2)
        target = self.targets[index]

        if self.transform is not None:
            img1, img2 = self.transform(img1), self.transform(img2)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img1, img2, target
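A minimal usage sketch of the two datasets defined above. The root path, transform, and DataLoader settings are illustrative choices rather than part of the commit; download=True fetches the image archive and annotation files from the URLs listed in file_dict:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor()])

# Identification view: each sample is (image, class index of the identity).
people = datasets.LFWPeople(root="./data", split="train", image_set="funneled",
                            transform=transform, download=True)
img, target = people[0]

# Verification view: each sample is (image1, image2, 1 for same identity, 0 otherwise).
pairs = datasets.LFWPairs(root="./data", split="test", image_set="funneled",
                          transform=transform, download=True)
img1, img2, same = pairs[0]

# The pairs dataset batches like any other torchvision dataset.
loader = torch.utils.data.DataLoader(pairs, batch_size=32, shuffle=True)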
