import warnings
from distutils.version import LooseVersion
from typing import Callable, Optional, Sequence, Union

import torch

from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = [
    "FID",
]


def fid_score(
    mu1: torch.Tensor, mu2: torch.Tensor, sigma1: torch.Tensor, sigma2: torch.Tensor, eps: float = 1e-6
) -> float:

    try:
        import numpy as np
    except ImportError:
        raise RuntimeError("fid_score requires numpy to be installed.")

    try:
        import scipy.linalg
    except ImportError:
        raise RuntimeError("fid_score requires scipy to be installed.")

    mu1, mu2 = mu1.cpu(), mu2.cpu()
    sigma1, sigma2 = sigma1.cpu(), sigma2.cpu()

    diff = mu1 - mu2

    # Product might be almost singular; sqrtm operates on numpy arrays, so convert explicitly
    covmean, _ = scipy.linalg.sqrtm(sigma1.mm(sigma2).numpy(), disp=False)
    # Numerical error might give a slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f"Imaginary component {m}")
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    # Fall back to an eps-regularized trace if sqrtm produced non-finite values
    if not np.isfinite(covmean).all():
        tr_covmean = np.sum(np.sqrt(((np.diag(sigma1) * eps) * (np.diag(sigma2) * eps)) / (eps * eps)))

    return float(diff.dot(diff).item() + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean)
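
# ``fid_score`` can also be called directly on precomputed statistics. A minimal,
# hedged sketch (``feats_a`` and ``feats_b`` are hypothetical (N, D) float tensors):
#
#     def _stats(feats):
#         mu = feats.mean(dim=0)
#         centered = feats - mu
#         return mu, centered.t().mm(centered) / (feats.shape[0] - 1)
#
#     mu1, sigma1 = _stats(feats_a)
#     mu2, sigma2 = _stats(feats_b)
#     print(fid_score(mu1, mu2, sigma1, sigma2))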


class InceptionExtractor:
    def __init__(self) -> None:
        try:
            from torchvision import models
        except ImportError:
            raise RuntimeError("This module requires torchvision to be installed.")
        self.model = models.inception_v3(pretrained=True)
        self.model.fc = torch.nn.Identity()
        self.model.eval()

    @torch.no_grad()
    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        if data.dim() != 4:
            raise ValueError(f"Inputs should be a tensor of dim 4, got {data.dim()}")
        if data.shape[1] != 3:
            raise ValueError(f"Inputs should be a tensor with 3 channels, got {data.shape}")
        return self.model(data)
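
# A hedged usage sketch for ``InceptionExtractor``: torchvision's ``inception_v3``
# expects normalized RGB batches of spatial size 299x299, so resize inputs first.
#
#     extractor = InceptionExtractor()
#     images = torch.rand(4, 3, 299, 299)  # hypothetical image batch
#     features = extractor(images)         # -> tensor of shape (4, 2048)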


class FID(Metric):
    r"""Calculates Frechet Inception Distance.

    .. math::
       \text{FID} = \|\mu_{1} - \mu_{2}\|_{2}^{2} + \text{Tr}(\sigma_{1} + \sigma_{2} - 2\sqrt{\sigma_{1}\sigma_{2}})

    where :math:`\mu_1` and :math:`\sigma_1` refer to the mean and covariance of the train data and
    :math:`\mu_2` and :math:`\sigma_2` refer to the mean and covariance of the test data.

    More details can be found in `Heusel et al. 2017`__

    __ https://arxiv.org/pdf/1706.08500.pdf

    In addition, a faster and online computation approach can be found in `Mathiasen et al. 2020`__

    __ https://arxiv.org/pdf/2009.14075.pdf

    Remark:

        This implementation is inspired by the ``pytorch_fid`` package, which can be found `here`__

        __ https://github.com/mseitzer/pytorch-fid

    Args:
        num_features: number of features; must be defined if ``feature_extractor`` is also defined.
            Otherwise, the default value is 2048.
        feature_extractor: a callable for extracting the features from the input data. If neither
            ``num_features`` nor ``feature_extractor`` is defined, ``InceptionExtractor`` is used by default.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Example:

        .. code-block:: python

            import torch
            from ignite.metrics.gan import FID

            # y_pred and y are already extracted features here, so pass
            # num_features and rely on the identity feature extractor
            y_pred, y = torch.rand(10, 2048), torch.rand(10, 2048)
            m = FID(num_features=2048)
            m.update((y_pred, y))
            print(m.compute())

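        A sketch of use with an :class:`~ignite.engine.engine.Engine`; ``eval_step``
        and ``data_loader`` below are illustrative assumptions, not part of this module:

        .. code-block:: python

            from ignite.engine import Engine

            def eval_step(engine, batch):
                y_pred, y = batch  # e.g. generated and real feature batches
                return y_pred, y

            evaluator = Engine(eval_step)
            FID(num_features=2048).attach(evaluator, "fid")
            state = evaluator.run(data_loader)
            print(state.metrics["fid"])
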
    .. versionadded:: 0.5.0
    """

    def __init__(
        self,
        num_features: Optional[int] = None,
        feature_extractor: Optional[Callable] = None,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
    ) -> None:

        try:
            import numpy as np  # noqa: F401
        except ImportError:
            raise RuntimeError("This module requires numpy to be installed.")

        try:
            import scipy  # noqa: F401
        except ImportError:
            raise RuntimeError("This module requires scipy to be installed.")

        # Default to the Inception v3 extractor with its 2048-dimensional features
        if num_features is None and feature_extractor is None:
            num_features = 2048
            feature_extractor = InceptionExtractor()
        elif num_features is None:
            raise ValueError("Argument num_features should be defined")
        elif feature_extractor is None:
            # Inputs are assumed to be features already, so use an identity extractor
            feature_extractor = lambda x: x  # noqa: E731

        if num_features <= 0:
            raise ValueError(f"Argument num_features must be greater than zero, got: {num_features}")
        self._num_features = num_features
        self._feature_extractor = feature_extractor
        self._eps = 1e-6
        super(FID, self).__init__(output_transform=output_transform, device=device)

    @staticmethod
    def _online_update(features: torch.Tensor, total: torch.Tensor, sigma: torch.Tensor) -> None:
        total += features
        # torch.outer is not available on older torch versions; fall back to torch.ger
        if LooseVersion(torch.__version__) <= LooseVersion("1.7.0"):
            sigma += torch.ger(features, features)
        else:
            sigma += torch.outer(features, features)

    def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
        r"""
        Calculates covariance from the accumulated sum of features and sum of their outer products
        """
        # Use the same version guard as _online_update so older torch versions work here too
        if LooseVersion(torch.__version__) <= LooseVersion("1.7.0"):
            sub_matrix = torch.ger(total, total)
        else:
            sub_matrix = torch.outer(total, total)
        sub_matrix = sub_matrix / self._num_examples
        return (sigma - sub_matrix) / (self._num_examples - 1)
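
    # The identity used above: for features x_1, ..., x_N with t = sum_i x_i and
    # S = sum_i outer(x_i, x_i), the unbiased covariance is (S - outer(t, t) / N) / (N - 1).
    # A hedged sanity check, not part of the metric:
    #
    #     x = torch.randn(100, 8, dtype=torch.float64)
    #     t, S = x.sum(dim=0), x.t().mm(x)
    #     cov = (S - torch.outer(t, t) / 100) / 99
    #     # cov matches np.cov(x.numpy(), rowvar=False) up to floating point error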

    @staticmethod
    def _check_feature_input(train: torch.Tensor, test: torch.Tensor) -> None:
        for feature in [train, test]:
            if feature.dim() != 2:
                raise ValueError(f"Features must be a tensor of dim 2, got: {feature.dim()}")
            if feature.shape[0] == 0:
                raise ValueError(f"Batch size should be greater than zero, got: {feature.shape[0]}")
            if feature.shape[1] == 0:
                raise ValueError(f"Feature size should be greater than zero, got: {feature.shape[1]}")
        if train.shape[0] != test.shape[0] or train.shape[1] != test.shape[1]:
            raise ValueError(
                f"Train and test features should have the same shape, got: {train.shape} != {test.shape}"
            )

    @reinit__is_reduced
    def reset(self) -> None:
        self._train_sigma = torch.zeros(
            (self._num_features, self._num_features), dtype=torch.float64, device=self._device
        )
        self._train_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
        self._test_sigma = torch.zeros(
            (self._num_features, self._num_features), dtype=torch.float64, device=self._device
        )
        self._test_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
        self._num_examples = 0
        super(FID, self).reset()

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:

        # Extract the features from the outputs
        train_features = self._feature_extractor(output[0].detach()).to(self._device)
        test_features = self._feature_extractor(output[1].detach()).to(self._device)

        # Check the feature shapes
        self._check_feature_input(train_features, test_features)

        # Accumulate the sum and outer-product sum for the train features
        for features in train_features:
            self._online_update(features, self._train_total, self._train_sigma)

        # Accumulate the sum and outer-product sum for the test features
        for features in test_features:
            self._online_update(features, self._test_total, self._test_sigma)

        self._num_examples += train_features.shape[0]

    @sync_all_reduce("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")
    def compute(self) -> float:
        fid = fid_score(
            mu1=self._train_total / self._num_examples,
            mu2=self._test_total / self._num_examples,
            sigma1=self._get_covariance(self._train_sigma, self._train_total),
            sigma2=self._get_covariance(self._test_sigma, self._test_total),
            eps=self._eps,
        )
        if torch.isnan(torch.tensor(fid)) or torch.isinf(torch.tensor(fid)):
            warnings.warn("The product of covariance of train and test features is out of bounds.")
        return fid
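

if __name__ == "__main__":
    # Hedged smoke test, not part of the public API: FID between two random
    # feature batches using the identity extractor (requires numpy and scipy).
    torch.manual_seed(0)
    m = FID(num_features=16)
    m.update((torch.rand(64, 16), torch.rand(64, 16)))
    print(m.compute())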