Skip to content

Commit 6f4a488

Browse files
rohitgr7Borda
andauthored
Add functional regression metrics (#2492)
* Add functional regression metrics * add functional tests * add docs * changelog * init * pep8 * docs * docs * setup docs * docs * Apply suggestions from code review * Apply suggestions from code review * typo Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka <[email protected]>
1 parent 4bbcfa0 commit 6f4a488

File tree

11 files changed

+380
-184
lines changed

11 files changed

+380
-184
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Added a PSNR metric: peak signal-to-noise ratio ([#2483](https://github.com/PyTorchLightning/pytorch-lightning/pull/2483))
1313

14+
- Added functional regression metrics ([#2492](https://github.com/PyTorchLightning/pytorch-lightning/pull/2492))
15+
1416
### Changed
1517

1618

docs/source/introduction_guide.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ quickly becomes onerous and starts distracting from the core research code.
2222
Goal of this guide
2323
------------------
2424
This guide walks through the major parts of the library to help you understand
25-
what each parts does. But at the end of the day, you write the same PyTorch code... just organize it
25+
what each part does. But at the end of the day, you write the same PyTorch code... just organize it
2626
into the LightningModule template which means you keep ALL the flexibility without having to deal with
2727
any of the boilerplate code
2828

docs/source/metrics.rst

+30
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,36 @@ iou (F)
361361
.. autofunction:: pytorch_lightning.metrics.functional.iou
362362
:noindex:
363363

364+
mse (F)
365+
^^^^^^^
366+
367+
.. autofunction:: pytorch_lightning.metrics.functional.mse
368+
:noindex:
369+
370+
rmse (F)
371+
^^^^^^^^
372+
373+
.. autofunction:: pytorch_lightning.metrics.functional.rmse
374+
:noindex:
375+
376+
mae (F)
377+
^^^^^^^
378+
379+
.. autofunction:: pytorch_lightning.metrics.functional.mae
380+
:noindex:
381+
382+
rmsle (F)
383+
^^^^^^^^^
384+
385+
.. autofunction:: pytorch_lightning.metrics.functional.rmsle
386+
:noindex:
387+
388+
psnr (F)
389+
^^^^^^^^
390+
391+
.. autofunction:: pytorch_lightning.metrics.functional.psnr
392+
:noindex:
393+
364394
stat_scores_multiple_classes (F)
365395
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
366396

pytorch_lightning/core/hooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def prepare_data(self):
3939
# don't do this
4040
self.something = else
4141
42-
def setup(step):
42+
def setup(stage):
4343
data = Load_data(...)
4444
self.l1 = nn.Linear(28, data.num_classes)
4545

pytorch_lightning/core/lightning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1321,7 +1321,7 @@ def prepare_data(self):
13211321
13221322
model.prepare_data()
13231323
if ddp/tpu: init()
1324-
model.setup(step)
1324+
model.setup(stage)
13251325
model.train_dataloader()
13261326
model.val_dataloader()
13271327
model.test_dataloader()

pytorch_lightning/metrics/__init__.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
22
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
33
from pytorch_lightning.metrics.regression import (
4+
MAE,
45
MSE,
6+
PSNR,
57
RMSE,
6-
MAE,
7-
RMSLE,
8-
PSNR
8+
RMSLE
99
)
1010
from pytorch_lightning.metrics.classification import (
1111
Accuracy,
@@ -48,10 +48,10 @@
4848
'IoU',
4949
]
5050
__regression_metrics = [
51+
'MAE',
5152
'MSE',
53+
'PSNR',
5254
'RMSE',
53-
'MAE',
54-
'RMSLE',
55-
'PSNR'
55+
'RMSLE'
5656
]
5757
__all__ = __regression_metrics + __classification_metrics + ['SklearnMetric']

pytorch_lightning/metrics/functional/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@
2020
to_onehot,
2121
iou,
2222
)
23+
from pytorch_lightning.metrics.functional.regression import (
24+
mae,
25+
mse,
26+
psnr,
27+
rmse,
28+
rmsle
29+
)

pytorch_lightning/metrics/functional/regression.py

+149-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,139 @@
44
from pytorch_lightning.metrics.functional.reduction import reduce
55

66

7+
def mse(
8+
pred: torch.Tensor,
9+
target: torch.Tensor,
10+
reduction: str = 'elementwise_mean'
11+
) -> torch.Tensor:
12+
"""
13+
Computes mean squared error
14+
15+
Args:
16+
pred: estimated labels
17+
target: ground truth labels
18+
reduction: method for reducing mse (default: takes the mean)
19+
Available reduction methods:
20+
21+
- elementwise_mean: takes the mean
22+
- none: pass array
23+
- sum: add elements
24+
25+
Return:
26+
Tensor with MSE
27+
28+
Example:
29+
30+
>>> x = torch.tensor([0., 1, 2, 3])
31+
>>> y = torch.tensor([0., 1, 2, 2])
32+
>>> mse(x, y)
33+
tensor(0.2500)
34+
35+
"""
36+
mse = F.mse_loss(pred, target, reduction='none')
37+
mse = reduce(mse, reduction=reduction)
38+
return mse
39+
40+
41+
def rmse(
42+
pred: torch.Tensor,
43+
target: torch.Tensor,
44+
reduction: str = 'elementwise_mean'
45+
) -> torch.Tensor:
46+
"""
47+
Computes root mean squared error
48+
49+
Args:
50+
pred: estimated labels
51+
target: ground truth labels
52+
reduction: method for reducing rmse (default: takes the mean)
53+
Available reduction methods:
54+
55+
- elementwise_mean: takes the mean
56+
- none: pass array
57+
- sum: add elements
58+
59+
Return:
60+
Tensor with RMSE
61+
62+
63+
>>> x = torch.tensor([0., 1, 2, 3])
64+
>>> y = torch.tensor([0., 1, 2, 2])
65+
>>> rmse(x, y)
66+
tensor(0.5000)
67+
68+
"""
69+
rmse = torch.sqrt(mse(pred, target, reduction=reduction))
70+
return rmse
71+
72+
73+
def mae(
74+
pred: torch.Tensor,
75+
target: torch.Tensor,
76+
reduction: str = 'elementwise_mean'
77+
) -> torch.Tensor:
78+
"""
79+
Computes mean absolute error
80+
81+
Args:
82+
pred: estimated labels
83+
target: ground truth labels
84+
reduction: method for reducing mae (default: takes the mean)
85+
Available reduction methods:
86+
87+
- elementwise_mean: takes the mean
88+
- none: pass array
89+
- sum: add elements
90+
91+
Return:
92+
Tensor with MAE
93+
94+
Example:
95+
96+
>>> x = torch.tensor([0., 1, 2, 3])
97+
>>> y = torch.tensor([0., 1, 2, 2])
98+
>>> mae(x, y)
99+
tensor(0.2500)
100+
101+
"""
102+
mae = F.l1_loss(pred, target, reduction='none')
103+
mae = reduce(mae, reduction=reduction)
104+
return mae
105+
106+
107+
def rmsle(
108+
pred: torch.Tensor,
109+
target: torch.Tensor,
110+
reduction: str = 'elementwise_mean'
111+
) -> torch.Tensor:
112+
"""
113+
Computes root mean squared log error
114+
115+
Args:
116+
pred: estimated labels
117+
target: ground truth labels
118+
reduction: method for reducing rmsle (default: takes the mean)
119+
Available reduction methods:
120+
121+
- elementwise_mean: takes the mean
122+
- none: pass array
123+
- sum: add elements
124+
125+
Return:
126+
Tensor with RMSLE
127+
128+
Example:
129+
130+
>>> x = torch.tensor([0., 1, 2, 3])
131+
>>> y = torch.tensor([0., 1, 2, 2])
132+
>>> rmsle(x, y)
133+
tensor(0.0207)
134+
135+
"""
136+
rmsle = mse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction)
137+
return rmsle
138+
139+
7140
def psnr(
8141
pred: torch.Tensor,
9142
target: torch.Tensor,
@@ -12,14 +145,22 @@ def psnr(
12145
reduction: str = 'elementwise_mean'
13146
) -> torch.Tensor:
14147
"""
15-
Computes the peak signal-to-noise ratio metric
148+
Computes the peak signal-to-noise ratio
16149
17150
Args:
18151
pred: estimated signal
19152
target: groun truth signal
20-
data_range: the range of the data. If None, it is determined from the data (max - min).
153+
data_range: the range of the data. If None, it is determined from the data (max - min)
21154
base: a base of a logarithm to use (default: 10)
22155
reduction: method for reducing psnr (default: takes the mean)
156+
Available reduction methods:
157+
158+
- elementwise_mean: takes the mean
159+
- none: pass array
160+
- sum add elements
161+
162+
Return:
163+
Tensor with PSNR score
23164
24165
Example:
25166
@@ -29,12 +170,15 @@ def psnr(
29170
>>> metric = PSNR()
30171
>>> metric(pred, target)
31172
tensor(2.5527)
173+
32174
"""
33175

34176
if data_range is None:
35177
data_range = max(target.max() - target.min(), pred.max() - pred.min())
36178
else:
37179
data_range = torch.tensor(float(data_range))
38-
mse = F.mse_loss(pred.view(-1), target.view(-1), reduction=reduction)
39-
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse)
40-
return psnr_base_e * (10 / torch.log(torch.tensor(base)))
180+
181+
mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction)
182+
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
183+
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
184+
return psnr

0 commit comments

Comments
 (0)