Skip to content

Commit 4cbc0fa

Browse files
guciferDesrozierssdesrozisvfdev-5
authored
Fid Metric (pytorch#2049)
* FID metric * improve the default values * autopep8 fix * Format fix * Docs update * Import tests * return type fix * Update tests/ignite/metrics/gan/test_fid.py Co-authored-by: Sylvain Desroziers <[email protected]> * Fixed test * Dummy Inception Class for testing * Added Inheritance * Added test init file * Added new tests * Fixed mypy errors * Added edge case for infinite * Used standard limit variables * Used standard limit variables * Improved user messages * Added ger for previous torch versions * LooseVersion * Docs update * Warning and Formula change * Made get_covariance private * Mypy fix * Test fix * Update ignite/metrics/gan/fid.py * autopep8 fix * Fixed Docs * Trace change * Float type output * Test fix * Convert everything to pytorch * Numpy complex check * Numpy as a dependency Co-authored-by: Desroziers <[email protected]> Co-authored-by: sdesrozis <[email protected]> Co-authored-by: Sylvain Desroziers <[email protected]> Co-authored-by: vfdev <[email protected]> Co-authored-by: vfdev-5 <[email protected]>
1 parent 6687900 commit 4cbc0fa

File tree

8 files changed

+502
-0
lines changed

8 files changed

+502
-0
lines changed

docs/source/metrics.rst

+1
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ Complete list of metrics
334334
Rouge
335335
RougeL
336336
RougeN
337+
FID
337338

338339
Helpers for customizing metrics
339340
-------------------------------

ignite/metrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ignite.metrics.epoch_metric import EpochMetric
66
from ignite.metrics.fbeta import Fbeta
77
from ignite.metrics.frequency import Frequency
8+
from ignite.metrics.gan.fid import FID
89
from ignite.metrics.loss import Loss
910
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
1011
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
@@ -37,6 +38,7 @@
3738
"DiceCoefficient",
3839
"EpochMetric",
3940
"Fbeta",
41+
"FID",
4042
"GeometricAverage",
4143
"IoU",
4244
"mIoU",

ignite/metrics/gan/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ignite.metrics.gan.fid import FID
2+
3+
__all__ = [
4+
"FID",
5+
]

ignite/metrics/gan/fid.py

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import warnings
2+
from distutils.version import LooseVersion
3+
from typing import Callable, Optional, Sequence, Union
4+
5+
import torch
6+
7+
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
8+
9+
__all__ = [
10+
"FID",
11+
]
12+
13+
14+
def fid_score(
15+
mu1: torch.Tensor, mu2: torch.Tensor, sigma1: torch.Tensor, sigma2: torch.Tensor, eps: float = 1e-6
16+
) -> float:
17+
18+
try:
19+
import numpy as np
20+
except ImportError:
21+
raise RuntimeError("fid_score requires numpy to be installed.")
22+
23+
try:
24+
import scipy
25+
except ImportError:
26+
raise RuntimeError("fid_score requires scipy to be installed.")
27+
28+
mu1, mu2 = mu1.cpu(), mu2.cpu()
29+
sigma1, sigma2 = sigma1.cpu(), sigma2.cpu()
30+
31+
diff = mu1 - mu2
32+
33+
# Product might be almost singular
34+
covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2), disp=False)
35+
# Numerical error might give slight imaginary component
36+
if np.iscomplexobj(covmean):
37+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
38+
m = np.max(np.abs(covmean.imag))
39+
raise ValueError("Imaginary component {}".format(m))
40+
covmean = covmean.real
41+
42+
tr_covmean = np.trace(covmean)
43+
44+
if not np.isfinite(covmean).all():
45+
tr_covmean = np.sum(np.sqrt(((np.diag(sigma1) * eps) * (np.diag(sigma2) * eps)) / (eps * eps)))
46+
47+
return float(diff.dot(diff).item() + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean)
48+
49+
50+
class InceptionExtractor:
51+
def __init__(self) -> None:
52+
try:
53+
from torchvision import models
54+
except ImportError:
55+
raise RuntimeError("This module requires torchvision to be installed.")
56+
self.model = models.inception_v3(pretrained=True)
57+
self.model.fc = torch.nn.Identity()
58+
self.model.eval()
59+
60+
@torch.no_grad()
61+
def __call__(self, data: torch.Tensor) -> torch.Tensor:
62+
if data.dim() != 4:
63+
raise ValueError(f"Inputs should be a tensor of dim 4, got {data.dim()}")
64+
if data.shape[1] != 3:
65+
raise ValueError(f"Inputs should be a tensor with 3 channels, got {data.shape}")
66+
return self.model(data)
67+
68+
69+
class FID(Metric):
70+
r"""Calculates Frechet Inception Distance.
71+
72+
.. math::
73+
\text{FID} = |\mu_{1} - \mu_{2}| + \text{Tr}(\sigma_{1} + \sigma_{2} - {2}\sqrt{\sigma_1*\sigma_2})
74+
75+
where :math:`\mu_1` and :math:`\sigma_1` refer to the mean and covariance of the train data and
76+
:math:`\mu_2` and :math:`\sigma_2` refer to the mean and covariance of the test data.
77+
78+
More details can be found in `Heusel et al. 2002`__
79+
80+
__ https://arxiv.org/pdf/1706.08500.pdf
81+
82+
In addition, a faster and online computation approach can be found in `Chen et al. 2014`__
83+
84+
__ https://arxiv.org/pdf/2009.14075.pdf
85+
86+
Remark:
87+
88+
This implementation is inspired by pytorch_fid package which can be found `here`__
89+
90+
__ https://github.com/mseitzer/pytorch-fid
91+
92+
Args:
93+
num_features: number of features, must be defined if the parameter ``feature_extractor`` is also defined.
94+
Otherwise, default value is 2048.
95+
feature_extractor: a callable for extracting the features from the input data. If neither num_features nor
96+
feature_extractor are defined, default value is ``InceptionExtractor``.
97+
output_transform: a callable that is used to transform the
98+
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
99+
form expected by the metric. This can be useful if, for example, you have a multi-output model and
100+
you want to compute the metric with respect to one of the outputs.
101+
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
102+
device: specifies which device updates are accumulated on. Setting the
103+
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
104+
non-blocking. By default, CPU.
105+
106+
Example:
107+
108+
.. code-block:: python
109+
110+
import torch
111+
from ignite.metric.gan import FID
112+
113+
y_pred, y = torch.rand(10, 2048), torch.rand(10, 2048)
114+
m = FID()
115+
m.update((y_pred, y))
116+
print(m.compute())
117+
118+
.. versionadded:: 0.5.0
119+
"""
120+
121+
def __init__(
122+
self,
123+
num_features: Optional[int] = None,
124+
feature_extractor: Optional[Callable] = None,
125+
output_transform: Callable = lambda x: x,
126+
device: Union[str, torch.device] = torch.device("cpu"),
127+
) -> None:
128+
129+
try:
130+
import numpy as np # noqa: F401
131+
except ImportError:
132+
raise RuntimeError("This module requires numpy to be installed.")
133+
134+
try:
135+
import scipy # noqa: F401
136+
except ImportError:
137+
raise RuntimeError("This module requires scipy to be installed.")
138+
139+
# default is inception
140+
if num_features is None and feature_extractor is None:
141+
num_features = 2048
142+
feature_extractor = InceptionExtractor()
143+
elif num_features is None:
144+
raise ValueError("Argument num_features should be defined")
145+
elif feature_extractor is None:
146+
self._feature_extractor = lambda x: x
147+
feature_extractor = self._feature_extractor
148+
149+
if num_features <= 0:
150+
raise ValueError(f"Argument num_features must be greater to zero, got: {num_features}")
151+
self._num_features = num_features
152+
self._feature_extractor = feature_extractor
153+
self._eps = 1e-6
154+
super(FID, self).__init__(output_transform=output_transform, device=device)
155+
156+
@staticmethod
157+
def _online_update(features: torch.Tensor, total: torch.Tensor, sigma: torch.Tensor) -> None:
158+
total += features
159+
if LooseVersion(torch.__version__) <= LooseVersion("1.7.0"):
160+
sigma += torch.ger(features, features)
161+
else:
162+
sigma += torch.outer(features, features)
163+
164+
def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
165+
r"""
166+
Calculates covariance from mean and sum of products of variables
167+
"""
168+
sub_matrix = torch.outer(total, total)
169+
sub_matrix = sub_matrix / self._num_examples
170+
return (sigma - sub_matrix) / (self._num_examples - 1)
171+
172+
@staticmethod
173+
def _check_feature_input(train: torch.Tensor, test: torch.Tensor) -> None:
174+
for feature in [train, test]:
175+
if feature.dim() != 2:
176+
raise ValueError(f"Features must be a tensor of dim 2, got: {feature.dim()}")
177+
if feature.shape[0] == 0:
178+
raise ValueError(f"Batch size should be greater than one, got: {feature.shape[0]}")
179+
if feature.shape[1] == 0:
180+
raise ValueError(f"Feature size should be greater than one, got: {feature.shape[1]}")
181+
if train.shape[0] != test.shape[0] or train.shape[1] != test.shape[1]:
182+
raise ValueError(
183+
f"Number of Training Features and Testing Features should be equal ({train.shape} != {test.shape})"
184+
)
185+
186+
@reinit__is_reduced
187+
def reset(self) -> None:
188+
self._train_sigma = torch.zeros((self._num_features, self._num_features), dtype=torch.float64).to(self._device)
189+
self._train_total = torch.zeros(self._num_features, dtype=torch.float64).to(self._device)
190+
self._test_sigma = torch.zeros((self._num_features, self._num_features), dtype=torch.float64).to(self._device)
191+
self._test_total = torch.zeros(self._num_features, dtype=torch.float64).to(self._device)
192+
self._num_examples = 0
193+
super(FID, self).reset()
194+
195+
@reinit__is_reduced
196+
def update(self, output: Sequence[torch.Tensor]) -> None:
197+
198+
# Extract the features from the outputs
199+
train_features = self._feature_extractor(output[0].detach()).to(self._device)
200+
test_features = self._feature_extractor(output[1].detach()).to(self._device)
201+
202+
# Check the feature shapess
203+
self._check_feature_input(train_features, test_features)
204+
205+
# Updates the mean and covariance for the train features
206+
for i, features in enumerate(train_features, start=self._num_examples + 1):
207+
self._online_update(features, self._train_total, self._train_sigma)
208+
209+
# Updates the mean and covariance for the test features
210+
for i, features in enumerate(test_features, start=self._num_examples + 1):
211+
self._online_update(features, self._test_total, self._test_sigma)
212+
213+
self._num_examples += train_features.shape[0]
214+
215+
@sync_all_reduce("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")
216+
def compute(self) -> float:
217+
fid = fid_score(
218+
mu1=self._train_total / self._num_examples,
219+
mu2=self._test_total / self._num_examples,
220+
sigma1=self._get_covariance(self._train_sigma, self._train_total),
221+
sigma2=self._get_covariance(self._test_sigma, self._test_total),
222+
eps=self._eps,
223+
)
224+
if torch.isnan(torch.tensor(fid)) or torch.isinf(torch.tensor(fid)):
225+
warnings.warn("The product of covariance of train and test features is out of bounds.")
226+
return fid

mypy.ini

+6
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,9 @@ ignore_missing_imports = True
7171

7272
[mypy-tqdm.*]
7373
ignore_missing_imports = True
74+
75+
[mypy-scipy.*]
76+
ignore_missing_imports = True
77+
78+
[mypy-torchvision.*]
79+
ignore_missing_imports = True

requirements-dev.txt

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ pytest-cov
66
pytest-xdist
77
dill
88
# Test contrib dependencies
9+
scipy
10+
pytorch_fid
911
tqdm
1012
scikit-learn
1113
matplotlib
@@ -17,6 +19,7 @@ wandb
1719
mlflow
1820
neptune-client
1921
tensorboard
22+
torchvision
2023
pynvml
2124
clearml
2225
scikit-image

tests/ignite/metrics/gan/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)